From e275d8fa337cdc19b7488b93906dcf13e9b52172 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 21 Aug 2024 22:59:29 -0700 Subject: [PATCH 01/54] RS: Compose multiple MMA tiles into a single swizzle pattern (#2710) Stacked on https://github.com/NVIDIA/Fuser/pull/2730 This PR is an incremental step towards supporting scheduling multiple tiles of MMA. The main thing that this PR does is to add a new test `HopperRS.FullSwizzle` and adjust our system to make this test work. I tried my best to document `HopperRS.FullSwizzle`, so I would recommend starting the review from that test. --- csrc/compute_at_map.cpp | 6 +- csrc/compute_at_map.h | 3 +- csrc/device_lower/lower2device.cpp | 3 + csrc/device_lower/pass/allocation.cpp | 80 +++++++++++ csrc/device_lower/pass/index.cpp | 10 +- csrc/id_model/indexing.cpp | 24 ++-- csrc/index_compute.cpp | 2 + tests/cpp/test_mma.cpp | 198 ++++++++++++++++++++++++++ 8 files changed, 312 insertions(+), 14 deletions(-) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 4d7b02f3cad..0176ec7563a 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -1514,9 +1514,9 @@ const DisjointSets& ComputeAtMap::getIdSets( NVF_ERROR(false, "Error with mapping mode provided."); } -bool ComputeAtMap::idExistsInMap(IterDomain* id) const { - return getIdSets(IdMappingMode::EXACT).disjointSetMap().find(id) != - getIdSets(IdMappingMode::EXACT).disjointSetMap().end(); +bool ComputeAtMap::idExistsInMap(IterDomain* id, IdMappingMode mode) const { + return getIdSets(mode).disjointSetMap().find(id) != + getIdSets(mode).disjointSetMap().end(); } VectorOfUniqueEntries>> diff --git a/csrc/compute_at_map.h b/csrc/compute_at_map.h index e46930b0740..91f0450955f 100644 --- a/csrc/compute_at_map.h +++ b/csrc/compute_at_map.h @@ -271,7 +271,8 @@ class ComputeAtMap { // Returns if the ID actually has a disjoint set meaning it has been processed // in the creation of the compute at map. - bool idExistsInMap(IterDomain* id) const; + bool idExistsInMap(IterDomain* id, IdMappingMode mode = IdMappingMode::EXACT) + const; //! Returns the pre-allocated index variable integer used in //! the ForLoop corresponding to the given IterDomain. diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 856ac75fca2..03d5834c0e8 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -339,6 +339,9 @@ bool requiresIdModel(Fusion* fusion) { return true; } } + if (expr->isA()) { + return true; + } } // If a tensor does not have a nice root->logical/allocation->loop // linear transformation history, use IdModel. diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index c7ad8fe1433..d78d909c8dd 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -9,8 +9,10 @@ #include #include #include +#include #include #include +#include #include #include @@ -174,6 +176,84 @@ class AllocationInserter : public kir::ExprMutator { info.allocation_domains = std::make_unique>(); + // TODO: Today, we always allocate loop domain, even if the allocation + // domain is explicitly set. This is clearly not the right thing to do, + // and we should fix this in the future. However, today, we still don't + // have a clear design of how to allocate tensors with explicit allocation + // domain. This problem is very difficult to solve, and there are many + // things to consider. For example, if the allocation domain contains a + // subset of inlined loop IDs, we should not allocate the inlined IDs. + // But what if the opposite is true? What if the allocation domain + // does not contain all inlined IDs? Is this considered an error, or + // a valid case that we need to infer which to allocate from the loop + // domain? We need to think about this carefully and come up with a + // clear design. For now, we just allocate the loop domain for historical + // reasons for all cases except for the Hopper MMA output tensor. + // + // Hopper MMA output tensor is a special case because the loop domain + // is scheduled in a way that the entire tile is parallelized by MMA, and + // The TIDx parallelization is a new broadcast dimension that is not + // connected to any other IterDomains. This way of scheduling effectively + // makes the loop domain 128x larger than the allocation domain, because + // the allocation domain is sharded on threads but the loop domain is not. + if ((info.buffer->definition()->isA() && + isHopper(info.buffer->definition()->as()->macro()))) { + const IdModel& id_model = GpuLower::current()->idModel(); + + std::unordered_set exclude_ca_ids; + for (auto i : c10::irange(info.alloc_pos)) { + auto ca_id = info.buffer->axis(i); + if (!ir_utils::isMemorySharedAcross( + info.buffer->getMemoryType(), ca_id->getParallelType())) { + exclude_ca_ids.insert(ca_id); + } + } + + const std::vector& domain_to_alloc = + info.buffer->hasAllocation() ? info.buffer->getAllocationDomain() + : info.buffer->getLoopDomain(); + + for (auto id : domain_to_alloc) { + if (exclude_ca_ids.find(id) == exclude_ca_ids.end()) { + // Don't use reduction/stride/broadcast/device axis in the + // allocation computation + if (id->isReduction() || id->isStride() || id->isBroadcast() || + id->isDeviceDim()) { + continue; + } + if (ir_utils::isMemoryPartitionedAcross( + info.buffer->getMemoryType(), id->getParallelType())) { + continue; + } + info.allocation_domains->push_back(id); + + // Loop promotion may affect allocations. Promotions of intermediate + // domains may not be defined correctly. Only consider loop domains + // for now. + bool is_loop = std::find( + info.buffer->getLoopDomain().begin(), + info.buffer->getLoopDomain().end(), + id) != info.buffer->getLoopDomain().end(); + if (is_loop) { + id = indexing_utils::getLoopPromotion(id, id_model); + } + + alloc_dims.push_back(id->extent()); + } else { + exclude_ca_ids.erase(id); + } + } + NVF_ERROR( + exclude_ca_ids.empty(), + "The non-allocating compute-at IDs are not found in the allocation domain. ", + "It is unclear how to allocate the tensor: ", + info.buffer->toString(), + " allocation domain: ", + ir_utils::toString(info.buffer->getAllocationDomain())); + + return alloc_dims; + } + for (const auto axis_i : c10::irange(info.buffer->nDims())) { const auto local_id = info.buffer->axis(axis_i); diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 0e0cd87d316..19b8aea782a 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1481,7 +1481,7 @@ static DataType getMmaInputBType(MmaMacro macro) { static inline DataType getMmaOutType(TensorView* mma_out) { int64_t size = 1; - for (auto id : mma_out->getLoopDomain()) { + for (auto id : mma_out->getAllocationDomain()) { if (id->isMma() && !id->isReduction()) { size *= id->extent()->evaluate().as(); } @@ -1624,7 +1624,13 @@ void IndexLowering::handle(const MmaOp* mma) { // smem. auto tv = mma->inB()->as(); auto swizzle = getSwizzleMode(tv); - auto base_addr = IrBuilder::baseAddressExpr(tv); + // Because the entire tile is parallelized on MMA, which are trivial + // loops and always have zero loop variables, the result of lowerSrcIndex + // will be the address of the first element of the tile, which happens to + // be the information we need to provide to the hardware. + auto base_addr = lowerSrcIndex(tv, mma->out(), {}, true) + ->as() + ->index(); int64_t leading_bytes = core_matrix_outer_size * getBytesFromSwizzle(swizzle); // swizzle period in bytes int64_t inner_size = diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index 842e398cbc1..e5d2f2a5cf1 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -182,8 +182,11 @@ class AllocationDomainSetup : private kir::IrVisitor { // reasonably well. bool use_set_allocation_domain = false; if (tv->hasAllocation()) { - // Honor the allocation domain if the tensor is global memory - if (tv->getMemoryType() == MemoryType::Global) { + // Honor the allocation domain if the tensor is global or Hopper MMA's + // output + if (tv->getMemoryType() == MemoryType::Global || + (tv->definition()->isA() && + isHopper(tv->definition()->as()->macro()))) { use_set_allocation_domain = true; } else if (tv->getMemoryType() == MemoryType::Shared) { // If it's a shared memory tensor, the set domain is likely @@ -217,25 +220,29 @@ class AllocationDomainSetup : private kir::IrVisitor { allocation_domains = tv->getAllocationDomain(); contiguity = tv->domain()->contiguity(); } else { - std::unordered_set exclude_ids; + std::unordered_set exclude_ca_ids; for (auto i : c10::irange(tv->getComputeAtPosition())) { auto ca_id = tv->axis(i); if (!ir_utils::isMemorySharedAcross( tv->getMemoryType(), ca_id->getParallelType())) { - exclude_ids.insert(ca_id); + exclude_ca_ids.insert(ca_id); } } for (auto i : c10::irange(tv->getAllocationDomain().size())) { auto id = tv->getAllocationDomain()[i]; - if (exclude_ids.find(id) == exclude_ids.end()) { + if (exclude_ca_ids.find(id) == exclude_ca_ids.end()) { + if (ir_utils::isMemoryPartitionedAcross( + tv->getMemoryType(), id->getParallelType())) { + continue; + } allocation_domains.push_back(id); contiguity.push_back(tv->domain()->contiguity()[i]); } else { - exclude_ids.erase(id); + exclude_ca_ids.erase(id); } } NVF_ERROR( - exclude_ids.empty(), + exclude_ca_ids.empty(), "The non-allocating compute-at IDs are not found in the allocation domain. ", "It is unclear how to allocate the tensor: ", tv->toString(), @@ -1029,7 +1036,8 @@ std::unordered_map TensorIndexer::getIndexReplacementMap( // of the domain, for predication, so the replacement is not // always done with zero. if (loop_id->getParallelType() == ParallelType::Vectorize || - loop_id->getParallelType() == ParallelType::Bulk) { + loop_id->getParallelType() == ParallelType::Bulk || + loop_id->getParallelType() == ParallelType::Mma) { replacement_index = loop_id->fusion()->zeroVal(); } else { ForLoop* for_loop = indexing_utils::getForLoop( diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 1b246ee625f..f18e25b41e2 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -2100,6 +2100,8 @@ kir::TensorIndex* Index::getProducerIndex( Val* index = nullptr; if (!lower_utils::hasRootToLoopLinearTransformations(producer) || + (consumer->definition()->isA() && + isHopper(consumer->definition()->as()->macro())) || (isIdModelOptionEnabled(IdModelEnableOption::ProducerIndex) && GpuLower::current()->isTensorIndexerEnabled())) { index = GpuLower::current()->tensorIndexer().getLinearIndex( diff --git a/tests/cpp/test_mma.cpp b/tests/cpp/test_mma.cpp index af11f2c4beb..b7dab16e626 100644 --- a/tests/cpp/test_mma.cpp +++ b/tests/cpp/test_mma.cpp @@ -412,6 +412,204 @@ TEST_P(HopperRS, SingleTile) { EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); } +// There are only three possible swizzle modes for smem operands of Hopper MMA: +// 32 byte, 64 byte, and 128 byte. Depending on the layout and the macro, the +// inner size may be smaller than the swizzle size. For example, if the macro is +// M64_N8_K16, and the layout is TT, then K is the inner dim, so the inner size +// is 16 items, that is, 32 bytes. If the swizzle mode is 128 byte, then the +// inner size is only 1/4 of the swizzle size. In the SingleTile test, we will +// just pad the inner dim to match the swizzle size, which is a 4x waste of smem +// space. In this test, instead of padding the inner dim, we will use four tiles +// to cover the entire swizzle size, so there is no waste of smem space. Note +// that composing four tiles to form a single swizzle pattern means that the +// memory layout of these four tiles will be interleaved with each other. The +// kernel we are getting is like this: +// +// For TN layout where the inner dimension is a reduction: +// load operand B from gmem to smem; +// accumulator = 0; +// for i in tiles: +// load operand A from gmem to register; +// accumulator += A * B; +// store accumulator to gmem; +// +// For TT layout where the inner dimension is not a reduction: +// load operand B from gmem to smem; +// for i in tiles: +// load operand A from gmem to register; +// accumulator = 0; +// accumulator += A * B; +// store accumulator to gmem; +TEST_P(HopperRS, FullSwizzle) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto swizzle_size = getBytesFromSwizzle(swizzle_b) / dataTypeSize(dtype); + auto inner_size = layout == MmaLayout::TT ? getN(macro) : getK(macro); + + if (swizzle_size / inner_size <= 1) { + GTEST_SKIP() + << "Already tested in SingleTile, not interested in testing it again"; + } + + if (swizzle_size % inner_size != 0) { + GTEST_SKIP() + << "We will be using swizzle size as CTA tile size, so it must be divisible"; + } + + const auto k_axis = layout == MmaLayout::TT ? 1 : 2; + + auto shapes = layout == MmaLayout::TT + ? matmulAtInputShape3DHopperRS( + getM(macro), swizzle_size, getK(macro), layout) + : matmulAtInputShape3DHopperRS( + getM(macro), getN(macro), swizzle_size, layout); + + auto tv0 = makeConcreteTensor(shapes.first, dtype); + auto tv1 = makeConcreteTensor(shapes.second, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Just doing a gmem->register copy + tv0 = set(tv0); + // Just doing a gmem->smem copy + tv1 = set(tv1); + tv1->setMemoryType(MemoryType::Shared); + + auto tv2 = fusedMultiplySum(tv0, tv1, {k_axis}); + + fusion.addOutput(tv2); + + auto mma_ops = ir_utils::getOpsOfType(&fusion); + NVF_CHECK( + 1 == mma_ops.size(), + "Invalid number of MmaOp instances in fusion definition, expected 1, got ", + mma_ops.size()); + mma_ops.front()->setMacro(macro); + + auto tv2c = tv2->cacheBefore(); + + moveInnerBroadcastLeft(tv0); // n, m, k + + // Split the inner dimension by the inner size, and reorder the outer + // of the split to dim 0. + if (layout == MmaLayout::TN) { + // inner is K, and K has multiple tiles + tv0->split(2, inner_size); + tv0->reorder({{-2, 0}}); + // ko, n, m, ki + } else { + // inner is N, and N has multiple tiles + tv0->split(0, inner_size); + // no, ni, m, k + } + + // Now, the inner 2 dimensions are a single MMA tile + tv0->applyMmaSwizzle(MmaOperand::A); + + tv0->merge(2); + tv0->merge(2); + tv0->axis(2)->parallelize(ParallelType::TIDx); + + // Just schedule tv1 the same way as in SingleTile. Note that although + // the schedule are the same, the memory layout of tv1 is different. + // For example, assume that the inner size is 16, and the swizzle size is 64. + // For the case of SingleTile, the input tensor size will be 16, so the inner + // dimension will be split as: + // 1, 64 = split(16, 64) + // For the case of FullSwizzle, the input tensor size will be 64, so the inner + // dimension will be split as: + // 1, 64 = split(64, 64) + tv1->applyMmaSwizzle(swizzle_b); + naivelyParallelize(tv1); + + // Split the inner dimension by the inner size, and reorder the outer + // of the split to dim 0. + tv2c->split(-1, inner_size); + tv2c->reorder({{-2, 0}}); + + // Now, the inner 3 dimensions are a single MMA tile. + // In the loop domain, just parallelize all of them as Mma. + tv2c->axis(1)->parallelize(ParallelType::Mma); + tv2c->axis(2)->parallelize(ParallelType::Mma); + tv2c->axis(3)->parallelize(ParallelType::Mma); + + if (layout == MmaLayout::TT) { + // [M, K, N] -> [M, N, K] + tv2c->reorder({{-1, -2}}); + } + + { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv2c->getLoopDomain()); + tv2c->setAllocationDomain(s.as(), true); + } + + // Create a dummy broadcasting IterDomain to denote that this instruction + // is a collective operation over 128 threads. This is a newly created + // broadcasting IterDomain and is not connected to other IterDomains in the + // TensorDomain. The reason for doing so is because the MMA instruction is + // really a collective operation over 128 threads, and by definition there is + // no per-thread assignment like "this thread works on this part of the + // tensor". It is actually all threads working on all data. For this reason, + // the threadIdx.x should not appear anywhere in the index of the tensor. + tv2c->broadcast(1, 128); + tv2c->axis(1)->parallelize(ParallelType::TIDx); + + if (layout == MmaLayout::TT) { + // If TN, then the inner dim is K, which is also the reduction dimension. + // For this case, K does not exist in tv2, so nothing to split. + // If TT, then the inner dim is N, which is not the reduction dimension. + // For this case, N exists in tv2, so we need to split it. + tv2->split(-1, inner_size); + tv2->reorder({{-2, 0}}); + } + { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv2->getLoopDomain()); + tv2->setLoopDomain(s.as()); + } + + // Inline gmem->register load into the MMA expression at 1. + // The shared loop is the loop over multiple tiles. + tv0->inlineAt(1); + // If TN, then the dim with multiple tiles is K, then the shared loop is a + // reduction loop. This reduction loop does not exist in the register->gmem + // store, so nothing to inline. + // If TT, then the dim with multiple tiles is N, then the shared loop is not a + // reduction loop. This shared loop exists in the register->gmem store, so we + // will inline the MMA expression into the register->gmem store. + if (layout == MmaLayout::TT) { + tv2c->inlineAt(1); + } + + auto inputs = + (layout == MmaLayout::TT ? matmulAtInput3DHopperRS( + getM(macro), + swizzle_size, + getK(macro), + layout, + data_type_to_aten(dtype)) + : matmulAtInput3DHopperRS( + getM(macro), + getN(macro), + swizzle_size, + layout, + data_type_to_aten(dtype))); + + FusionExecutor fe; + fe.compileFusion( + &fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams); + + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + + auto tref = atMatmul( + inputs.first.squeeze().to(at::kFloat), + inputs.second.squeeze().to(at::kFloat), + layout); + EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); +} + TEST_P(HopperRS, SingleTileWithTMALoad) { Fusion fusion; FusionGuard fg(&fusion); From dc7aec688b4fad90516ee46f087b9086f28c7871 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 22 Aug 2024 09:08:46 -0700 Subject: [PATCH 02/54] Add more comments so I remember how to parse results. (#2830) --- tools/benchmark_thunder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/benchmark_thunder.py b/tools/benchmark_thunder.py index 530d8a3948c..8552e0677b8 100644 --- a/tools/benchmark_thunder.py +++ b/tools/benchmark_thunder.py @@ -70,7 +70,10 @@ def run_settings(self, settings: Iterable[str]) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Runs Thunder benchmarks with multiple settings." + description="Runs Thunder benchmarks with multiple settings. " + "It stores benchmark results to the specified storage path, which " + "can be compared by running `pytest-benchmark --storage " + "compare --group-by name`." ) parser.add_argument( "settings", From 7d6f342fa399c78bb868c02bca8e1979c2539f91 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:39:15 -0700 Subject: [PATCH 03/54] Print StructHandle in toString(PolymorphicValue) [second attempt] (#2834) I merged #2810 too hastily and wound up breaking CI. This was reverted in #2829. This PR is a second attempt to fix printing of `StructHandle`. --- csrc/evaluator_common.cpp | 4 +-- csrc/polymorphic_value.cpp | 38 +++++++++++++++++++++++++++- csrc/polymorphic_value.h | 15 +---------- tests/cpp/test_polymorphic_value.cpp | 20 +++++++++++++++ 4 files changed, 60 insertions(+), 17 deletions(-) diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 9384bffe7c3..fa8a874dd72 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -222,8 +222,8 @@ void PrecomputedValues::print() const { debug() << "Precomputed Values:\n"; for (auto i : c10::irange(symbols_.size())) { if (defined_[i]) { - debug() << symbols_[i]->toInlineString() << " = " << values_[i] - << std::endl; + debug() << symbols_[i]->toInlineString() << " = " + << PolymorphicValue_functions::toString(values_[i]) << std::endl; } } } diff --git a/csrc/polymorphic_value.cpp b/csrc/polymorphic_value.cpp index 3ae0a9d731e..e2c838ef6e8 100644 --- a/csrc/polymorphic_value.cpp +++ b/csrc/polymorphic_value.cpp @@ -5,10 +5,11 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on - #include #include +#include + namespace nvfuser { bool StructHandle::operator==(const StructHandle& other) const { @@ -38,4 +39,39 @@ bool StructHandle::operator==(const StructHandle& other) const { return true; } +namespace PolymorphicValue_functions { + +std::string toString(const PolymorphicValue& v) { + std::stringstream ss; + if (v.is()) { + const auto& t = v.as(); + ss << "Tensor(sizes=" << t.sizes() << ", " + << "stride=" << t.strides() << ", dtype=" << t.dtype() + << ", device=" << t.device() << ", data_ptr=" << t.data_ptr() << ")"; + } else if (v.is()) { + ss << "std::monostate"; + } else if (v.is()) { + const StructHandle& hdl = v.as(); + StructType type = (v->*&StructHandle::type)(); + ss << "StructHandle<" << type.name << ">{"; + bool first = true; + for (size_t i : c10::irange(type.fields.size())) { + if (first) { + first = false; + } else { + ss << ", "; + } + const std::string& fieldname = type.fields.at(i).name; + ss << fieldname << "="; + ss << toString(hdl->*(fieldname)); + } + ss << "}"; + } else { + ss << v; + } + return ss.str(); +} + +} // namespace PolymorphicValue_functions + } // namespace nvfuser diff --git a/csrc/polymorphic_value.h b/csrc/polymorphic_value.h index 5775feb1173..645b35fda53 100644 --- a/csrc/polymorphic_value.h +++ b/csrc/polymorphic_value.h @@ -221,20 +221,7 @@ using PolymorphicValue = dynamic_type::DynamicType< namespace PolymorphicValue_functions { -inline std::string toString(const PolymorphicValue& v) { - std::stringstream ss; - if (v.is()) { - const auto& t = v.as(); - ss << "Tensor(sizes=" << t.sizes() << ", " - << "stride=" << t.strides() << ", dtype=" << t.dtype() - << ", device=" << t.device() << ", data_ptr=" << t.data_ptr() << ")"; - } else if (v.is()) { - ss << "std::monostate"; - } else { - ss << v; - } - return ss.str(); -} +NVF_API std::string toString(const PolymorphicValue& v); template inline bool isNan(const T& a) { diff --git a/tests/cpp/test_polymorphic_value.cpp b/tests/cpp/test_polymorphic_value.cpp index 57e5403348f..b471189a33c 100644 --- a/tests/cpp/test_polymorphic_value.cpp +++ b/tests/cpp/test_polymorphic_value.cpp @@ -43,6 +43,19 @@ TEST_F(PolymorphicValueTest, OpaqueEquality) { EXPECT_NE(c, a2); } +TEST_F(PolymorphicValueTest, OpaquePrint) { + Opaque a{DataType::Int}; + struct A { + int64_t x; + double y; + }; + Opaque a1(A{1, 2.0}); + EXPECT_THAT( + PolymorphicValue_functions::toString(a), testing::StartsWith("Opaque<")); + EXPECT_THAT( + PolymorphicValue_functions::toString(a1), testing::StartsWith("Opaque<")); +} + TEST_F(PolymorphicValueTest, Struct) { struct A : public Struct { int64_t x; @@ -107,6 +120,10 @@ TEST_F(PolymorphicValueTest, Struct) { EXPECT_EQ(*type.fields.at(1).type, DataType::Double); EXPECT_FALSE(type.fields.at(1).used_in_kernel); + EXPECT_EQ( + PolymorphicValue_functions::toString(a), + "StructHandle{x=2788, y=2.71828}"); + { // intentionally create a new scope and define another struct with the same // name to make sure the previous struct is not accessible @@ -138,6 +155,9 @@ TEST_F(PolymorphicValueTest, Struct) { EXPECT_EQ(b->*"y", PolymorphicValue(3.1415926)); EXPECT_EQ(type, (b->*&StructHandle::type)()); + EXPECT_EQ( + PolymorphicValue_functions::toString(b), + "StructHandle{x=299792458, y=3.14159}"); } } From 2becfb358deb8451db9cdf494843bcc755648500 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:35:27 -0700 Subject: [PATCH 04/54] Bind TensorMetaData to PrecomputedValues (#2812) This was found when implementing #2714. We currently do not bind TensorMetaData for input tensors to PrecomputedValues. This means we cannot evaluate expressions that contain them, which can lead to errors. This PR binds these metadata structs, which I think is the expected behavior. --- csrc/evaluator_common.cpp | 31 +++++++++++++++++++++++------ csrc/ir/nodes.cpp | 3 +++ csrc/utils.h | 26 ++++++++++++++++++++++++ tests/cpp/test_evaluator.cpp | 38 ++++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 6 deletions(-) diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index fa8a874dd72..1042931914e 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -56,10 +56,16 @@ std::vector makeSortedEvaluationList(std::vector input) { to_sort.pop_back(); } else { bool ready_to_pop = true; - for (auto producer : getImmediateProducers(top_val)) { - if (!visited.count(producer)) { - ready_to_pop = false; - to_sort.push_back(producer); + // Struct types must be bound directly. This is because it would + // otherwise require us to compute T0 just to compute GetMetaData(T0), + // for example. We skip computing producers of Structs here in order to + // avoid computing the TensorViews themselves. + if (!isStructType(top_val->dtype())) { + for (auto producer : getImmediateProducers(top_val)) { + if (!visited.count(producer)) { + ready_to_pop = false; + to_sort.push_back(producer); + } } } if (ready_to_pop) { @@ -105,7 +111,12 @@ std::vector collectRuntimeUsedValues(Fusion* fusion) { } } for (auto inp : fusion->inputs()) { - if (!inp->isA()) { + if (auto* tv = dynamic_cast(inp)) { + // For TensorView inputs, do not bind the TV itself. Only bind its + // TensorMetaData + Val* metadata = fusion->metadataOf(tv); + ret.push_back(metadata); + } else { ret.push_back(inp); } } @@ -198,7 +209,9 @@ void PrecomputedValues::initializeValueList( // Fill in constants and assign evaluator indices for (const auto i : c10::irange(num_of_values_)) { // Use an expression evaluator to test if value is const - if (sorted_value_list[i]->isConstScalar()) { + // Structs must be bound directly + if (!isStructType(sorted_value_list[i]->dtype()) && + sorted_value_list[i]->isConstScalar()) { is_constant_[i] = true; values_[i] = sorted_value_list[i]->evaluate(); } @@ -562,6 +575,9 @@ void NaiveValueMachine::runUnaryOp(int index) { case UnaryOpType::BitwiseNot: dest = ~src; break; + case UnaryOpType::Reciprocal: + dest = 1.0 / src; + break; case UnaryOpType::Signbit: dest = signbit(src); break; @@ -659,6 +675,9 @@ void NaiveValueMachine::runBinaryOp(int index) { case BinaryOpType::Fmod: dest = fmod(lhs, rhs); break; + case BinaryOpType::Pow: + dest = pow(lhs, rhs); + break; default: NVF_CHECK(false, "Unexpected operator type ", bop_type_[index]); } diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index fd7b0fa0849..6ce16a55113 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -658,6 +658,9 @@ std::vector BinaryOp::evaluate( case BinaryOpType::Complex: return {at::complex(lhs.as(), rhs.as())}; break; + case BinaryOpType::Pow: + return {pow(lhs, rhs)}; + break; default: NVF_CHECK( false, diff --git a/csrc/utils.h b/csrc/utils.h index 321214c9857..43e71b2be83 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -568,4 +568,30 @@ inline int64_t wrapDim(int64_t dim, int64_t ndim) { return dim; } +// This is the same as the pow utility included in runtime/helpers.cu. It is +// included here to facilitate matching host-side computation. +template +T pow(T a, T b) { + if (b < 0) { + if (a == 1) { + return 1; + } else if (a == -1) { + auto negative = (-b) % static_cast(2); + return negative ? -1 : 1; + } else { + return 0; + } + } else { + T result = 1; + while (b) { + if (b & 1) { + result *= a; + } + b /= 2; + a *= a; + } + return result; + } +} + } // namespace nvfuser diff --git a/tests/cpp/test_evaluator.cpp b/tests/cpp/test_evaluator.cpp index b4c815e52e6..91df57b8269 100644 --- a/tests/cpp/test_evaluator.cpp +++ b/tests/cpp/test_evaluator.cpp @@ -776,4 +776,42 @@ TEST_F(ExprEvalTest, BinaryOpFmod) { EXPECT_EQ(evaluator.evaluate(out5).as(), std::fmod(3, -0.8)); } +// Test that we properly bind tensor metadata in PrecomputedValues so that we +// can access it from an ExpressionEvaluator +TEST_F(ExprEvalTest, TensorMetadataPrecomputedValues) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto* tv1 = set(tv0); + fusion.addOutput(tv1); + + PrecomputedValues pv(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({3, 4}, options); + auto args = KernelArgumentHolder::createKernelArgumentHolder({t0}); + + // now compute metadata of tv0 + auto metadata = fusion.metadataOf(tv0); + ASSERT_TRUE(metadata != nullptr); + EXPECT_EQ(metadata->dtype(), metaDataTypeOf(tv0)); + auto logical_size = IrBuilder::getAttrExpr(metadata, "logical_size"); + auto logical_size_0 = IrBuilder::getItemExpr(logical_size, fusion.zeroVal()); + auto logical_size_1 = IrBuilder::getItemExpr(logical_size, fusion.oneVal()); + + pv.bindInputs(args); + pv.evaluate(); + + ExpressionEvaluator evaluator; + evaluator.bindPrecomputedValues(&pv); + + EXPECT_TRUE(evaluator.evaluate(metadata).hasValue()); + + checkIntValue(evaluator, logical_size_0, 3); + checkIntValue(evaluator, logical_size_1, 4); +} + } // namespace nvfuser From afba9938b6cdc2e537b948cb79fe54d4e190d262 Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Thu, 22 Aug 2024 15:14:39 -0400 Subject: [PATCH 05/54] remove no longer required runtime check of persistent batch size in innerOuter scheduler (#2776) After adding `getPersistentBufferStorageParams()`, the buffer size is correctly set and no longer require additional checks of whether we can get a "reasonable" persistent batch size. This PR changes `getOptionalInnerOuterPersistentBufferBatches` to `getBufferBatchSizeAndThreadsPerBlock`, as suggested by the name change, `optional` is deleted indicating it always return a pair of {BufferBatchSize, ThreadsPerBlock} --- csrc/scheduler/normalization_inner_outer.cpp | 94 ++++---------------- 1 file changed, 15 insertions(+), 79 deletions(-) diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index bbf57718020..a29cb3af26a 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -490,21 +490,12 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( // registers, this function will return that batch value. Otherwise, it will // return nullopt except when ignore_register_size_limit is true where it will // return whatever the batch value is. -// This exception is needed because the register usage in canScheduleRuntime is -// based on std::min(project_buffer, not_project_buffer). However, in -// getPersistentHeuristics() we enforce project_buffer to input if dtype=float -// and feature size <=14K. It leads to register spills but still faster than -// unprojected version due to the reuse of a input para in this grid persistent -// kernel. This is a tmp solution before we have a new persistent heuristics, -// where the projection should not soley based on size of buffers. -std::pair, int64_t> -getOptionalInnerOuterPersistentBufferBatches( +std::pair getBufferBatchSizeAndThreadsPerBlock( const int64_t inner_dim_numel, const int64_t outer_dim_numel, const int64_t persistent_buffer_size, const int64_t vectorize_factor, - const int64_t warp_size, - const bool ignore_register_size_limit) { + const int64_t warp_size) { // if inner_dim_numel <= 1024, we are doing multiple reductions per block // with a constant batch size of 1 if vectorized. See Step 5 of // innerOuterPersistentHeuristic. Although batch size is 1, each thread also @@ -534,13 +525,13 @@ getOptionalInnerOuterPersistentBufferBatches( } return 1l; }; - //! Each thread can use a maximum of 255 registers, and assume 40 of them are - //! reserved for indexing and other purposes. So, each thread can use up to - //! 215 registers for persistent buffer. Calculate number of buffer batches - //! using these 215 registers. total_buffer_bytes is the total size of - //! persistent buffers in bytes. reduction_elements is the number of elements - //! in the reduction domain. vectorization_factor is the vectorization factor - //! of inputs and outputs. + // Each thread can use a maximum of 255 registers, and assume 40 of them are + // reserved for indexing and other purposes. So, each thread can use up to + // 215 registers for persistent buffer. Calculate number of buffer batches + // using these 215 registers. total_buffer_bytes is the total size of + // persistent buffers in bytes. reduction_elements is the number of elements + // in the reduction domain. vectorization_factor is the vectorization factor + // of inputs and outputs. auto getMaximumInnerOuterPersistentBufferBatch = [&]() -> int64_t { int64_t register_per_batch = ceilDiv( persistent_buffer_size / inner_dim_numel * vectorize_factor, @@ -573,28 +564,7 @@ getOptionalInnerOuterPersistentBufferBatches( threads_per_block += warp_size; inner_batch = ceilDiv(after_vectorization, threads_per_block); } - // The maximum feature size can be processed without register spills and - // fusion segmentation for fp16 is 14K. Here, we can allow register spills to - // avoid fusion segmentation by incrase maximum batch size by 3. This allows - // us to process up to 20K features (14K + 256*8*3). - // Performance on A100-80G: - // (1) shape= 16384 x 16384, 1300 GB/s, time_us mean(var)= 1245.08 (8.89703), - // 64 bytes stack frame, 64 bytes spill stores, 128 bytes spill loads. (2) - // shape= 16384 x 18432, 1070 GB/s, time_us mean(var)= 1683.87 (19.527), 192 - // bytes stack frame, 192 bytes spill stores, 384 bytes spill loads. - // (3) shape= 16384 x 20480, 730 GB/s time_us mean(var)= 2766.64 (12.3883), - // 320 bytes stack frame, 320 bytes spill stores, 640 bytes spill loads. As a - // ref, the segmented version takes time_us mean(var)= 2841.91 (5.20231) - // without considering the overhead of fusion segmentation. - // (4) Disable this optimization if vectorize_factor is 1 due to high register - // usage in cases can't be vectorized. - const int64_t batch_max_reg_spill = - vectorize_factor > 1 ? batch_max + 3 : batch_max; - if (ignore_register_size_limit || inner_batch <= batch_max_reg_spill) { - return std::make_pair(inner_batch, threads_per_block); - } else { - return std::make_pair(std::nullopt, -1); - } + return std::make_pair(inner_batch, threads_per_block); } } // namespace @@ -647,22 +617,6 @@ bool InnerOuterPersistentKernelScheduler::canScheduleRunTime( return false; } - // check if we can schedule the combined reductions with a reasonable - // batch size without register spills. - if (!getOptionalInnerOuterPersistentBufferBatches( - properties.total_reduction_numel, - properties.total_iteration_numel, - buffer_params.regs_buffer_size, - (int64_t)vectorize_factor, - warp_size, - false) - .first.has_value()) { - scheduler_debug_utils::canScheduleRejectReason( - heuristicType(), - "Required batch number is larger than available batch number! Will cause register spills!"); - return false; - } - const int64_t device_max_threads_per_multiprocessor = (int64_t)at::cuda::getCurrentDeviceProperties() ->maxThreadsPerMultiProcessor; @@ -805,35 +759,17 @@ std::shared_ptr innerOuterPersistentHeuristic( // Step-1, set InnerParams reduction dim: inner_vect, inner_batch, // threads_per_block (bdimx * bdimy). Start threads_per_block from a quarter - // warp, gradually increase it. Runtime checkCombinedReductionShape ensures - // inner_dim_numel is dividable by the multiplication of a quarter warp and - // vectorize_factor. + // warp, gradually increase it. iop.inner_vect = (int64_t)vectorize_factor; - // ignore_register_size_limit will return a valid batch size. - // This is needed because we enforced projection for fp32 if the feature size - // is less or equal 14K. It leads to register spills but still faster than the - // unprojected version due to the reuse of a input para in this grid - // persistent kernel. However, when we do register usage check in - // canScheduleRuntime, the enforced projection is not considered. Thus, - // max_persistent_buffer_size used here is larger than the value used in - // canScheduleRuntime. - // This is a tmp solution before we have a new persistent heuristics, where - // the projection is not solely based on size of buffers. The enforced buffer - // projection is not considered in canScheduleRuntime Thus, - constexpr bool ignore_register_size_limit = true; - const auto& batch_and_block_size = - getOptionalInnerOuterPersistentBufferBatches( + const auto [persistent_batch, threads_per_block] = + getBufferBatchSizeAndThreadsPerBlock( inner_dim_numel, outer_dim_numel, regs_buffer_size, iop.inner_vect, - dev_prop->warpSize, - ignore_register_size_limit); - auto opt_inner_batch = batch_and_block_size.first; - NVF_ERROR(opt_inner_batch.has_value()); - iop.inner_batch = opt_inner_batch.value(); - int64_t threads_per_block = batch_and_block_size.second; + dev_prop->warpSize); + iop.inner_batch = persistent_batch; NVF_ERROR( iop.inner_vect * iop.inner_batch * threads_per_block >= inner_dim_numel, From 2138fe254dbdfe45192fda3b33a023519608daac Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 22 Aug 2024 12:58:44 -0700 Subject: [PATCH 06/54] Add `chunk` to `ops/alias.h` for convenience. (#2818) --- csrc/ir/builder.cpp | 1 + csrc/ops/alias.cpp | 27 +++++++++++ csrc/ops/alias.h | 9 ++++ tests/cpp/test_alias.cpp | 6 +-- tests/cpp/test_dynamic_transform.cpp | 2 +- tests/cpp/test_move_split_cat.cpp | 26 +++++----- tests/cpp/test_resize.cpp | 72 +++++++++++++++++++++++++--- 7 files changed, 116 insertions(+), 27 deletions(-) diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index 34f358954c2..813c14e1eb3 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -413,6 +413,7 @@ Val* SimplifyingIrBuilder::divExpr(Val* lhs, Val* rhs) { if (rhs->isOneInt()) { return lhs; } + return IrBuilder::divExpr(lhs, rhs); } diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 21406b4ce1d..4f895ac8697 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -844,4 +844,31 @@ TensorView* slice( return slice(inp, slices); } +std::vector chunk( + TensorView* in, + const int64_t chunks, + int64_t dim) { + NVF_CHECK(chunks > 0); + + const auto in_logical = TensorDomain::noReductions(in->getLogicalDomain()); + const auto num_dims = static_cast(in_logical.size()); + dim = wrapDim(dim, num_dims); + Val* dim_size = in_logical[dim]->extent(); + Val* slice_size = SimplifyingIrBuilder::ceilDivExpr( + dim_size, IrBuilder::create(chunks)); + + std::vector slices; + slices.reserve(chunks); + std::vector ranges(num_dims); + for (auto i : c10::irange(chunks)) { + ranges[dim].start = ranges[dim].stop; + ranges[dim].stop = + (i == chunks - 1 ? nullptr + : SimplifyingIrBuilder::mulExpr( + slice_size, IrBuilder::create(i + 1))); + slices.push_back(slice(in, ranges)); + } + return slices; +} + } // namespace nvfuser diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index bfa5c228144..91a93cf0243 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -135,4 +135,13 @@ NVF_API TensorView* slice( const std::vector& starts, const std::vector& stops); +// Splits `in`'s dimension `dim` into `chunks` chunks. All but the last chunk +// will be of size `ceil(dim_size/chunks)`. Unlike `torch.chunk` which returns +// only positive-size chunks and therefore may return fewer than `chunks` of +// them, this function returns exactly `chunks` chunks and a chunk of negative +// size will lead to a concretization error. This difference is because that we +// can't precompute the number of positive-size chunks when the dimension size +// is symbolic. +std::vector chunk(TensorView* in, int64_t chunks, int64_t dim); + } // namespace nvfuser diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp index b50cd5c490e..0a8a427a6b8 100644 --- a/tests/cpp/test_alias.cpp +++ b/tests/cpp/test_alias.cpp @@ -263,11 +263,7 @@ TEST_F(AliasTest, SliceViewPermute) { TensorView* in = makeContigConcreteTensor({batches, seq_length, features * 3}); fusion->addInput(in); - std::vector splits({ - slice(in, {0, 0, 0}, {batches, seq_length, features}), - slice(in, {0, 0, features}, {batches, seq_length, features * 2}), - slice(in, {0, 0, features * 2}, {batches, seq_length, features * 3}), - }); + std::vector splits = chunk(in, /*chunks=*/3, /*dim=*/-1); for (TensorView* split : splits) { split = reshape( split, diff --git a/tests/cpp/test_dynamic_transform.cpp b/tests/cpp/test_dynamic_transform.cpp index 8086d4f5035..4f337e79a07 100644 --- a/tests/cpp/test_dynamic_transform.cpp +++ b/tests/cpp/test_dynamic_transform.cpp @@ -94,7 +94,7 @@ TEST_F(NVFuserTest, DynamicTransform1_CUDA) { auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); }, - ::testing::ThrowsMessage(::testing::HasSubstr( + ::testing::ThrowsMessage(::testing::HasSubstr( "Values of -1 passed to reshape must be constant at definition"))); } diff --git a/tests/cpp/test_move_split_cat.cpp b/tests/cpp/test_move_split_cat.cpp index a620e69f3a6..e3aadda17b8 100644 --- a/tests/cpp/test_move_split_cat.cpp +++ b/tests/cpp/test_move_split_cat.cpp @@ -51,9 +51,8 @@ TEST_F(MoveSplitCatTest, Noncancellable_DifferentOrder) { FusionGuard fg(fusion.get()); TensorView* in = makeContigConcreteTensor({2, 6}); - TensorView* s0 = slice(in, {0, 0}, {2, 3}); - TensorView* s1 = slice(in, {0, 3}, {2, 6}); - TensorView* out = cat({s1, s0}, /*dim=*/-1); + std::vector slices = chunk(in, /*chunks=*/2, /*dim=*/-1); + TensorView* out = cat({slices[1], slices[0]}, /*dim=*/-1); fusion->addInput(in); fusion->addOutput(out); @@ -73,10 +72,9 @@ TEST_F(MoveSplitCatTest, Cancellable_SetWithoutPermute) { FusionGuard fg(fusion.get()); TensorView* in = makeContigConcreteTensor({2, 5}); - TensorView* s0 = slice(in, {0, 0}, {2, 2}); - TensorView* s1 = slice(in, {0, 2}, {2, 5}); - s0 = set(s0); - s1 = set(s1); + std::vector slices = chunk(in, /*chunks=*/2, /*dim=*/-1); + TensorView* s0 = set(slices[0]); + TensorView* s1 = set(slices[1]); TensorView* out = cat({s0, s1}, /*dim=*/-1); fusion->addInput(in); @@ -183,10 +181,9 @@ TEST_F(MoveSplitCatTest, Cancellable_IncompatibleAllocationOrder) { FusionGuard fg(fusion.get()); TensorView* in = makeContigConcreteTensor({2, 3, 5}); - TensorView* s0 = slice(in, {0, 0, 0}, {2, 3, 2}); - TensorView* s1 = slice(in, {0, 0, 2}, {2, 3, 5}); - s0 = permute(s0, {1, 0, 2}); - s1 = permute(s1, {1, 0, 2}); + std::vector slices = chunk(in, /*chunks=*/2, /*dim=*/-1); + TensorView* s0 = permute(slices[0], {1, 0, 2}); + TensorView* s1 = permute(slices[1], {1, 0, 2}); TensorView* out = cat({s0, s1}, /*dim=*/-1); out->setAllocationDomain({out->axis(2), out->axis(0), out->axis(1)}, true); @@ -247,11 +244,10 @@ TEST_F(MoveSplitCatTest, Noncancellable_WrongAxis) { FusionGuard fg(fusion.get()); TensorView* in = makeContigConcreteTensor({2, 2, 4}); - TensorView* s0 = slice(in, {0, 0, 0}, {2, 2, 2}); - TensorView* s1 = slice(in, {0, 0, 2}, {2, 2, 4}); + std::vector slices = chunk(in, /*num_slices=*/2, /*dim=*/-1); // dim=2 is the split dimension. - s0 = permute(s0, {1, 2, 0}); - s1 = permute(s1, {1, 2, 0}); + TensorView* s0 = permute(slices[0], {1, 2, 0}); + TensorView* s1 = permute(slices[1], {1, 2, 0}); // After permutation, dim=1 is the split dimension. However, the following // `cat` is along dim=0. TensorView* out = cat({s0, s1}, /*dim=*/0); diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 0cc1f85ce18..f8b2a2ecb02 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -26,8 +26,10 @@ namespace nvfuser { using ResizeTest = NVFuserTest; using testing::Each; +using testing::HasSubstr; using testing::Not; using testing::Property; +using testing::ThrowsMessage; using testing::UnorderedElementsAre; // Simple pad test @@ -2218,8 +2220,8 @@ TEST_F(ResizeTest, FusionSqueezeSymbolic) { EXPECT_THAT( [&]() { fec.runFusionWithInputs({t0, 10}); }, - ::testing::ThrowsMessage(::testing::HasSubstr( - "must concretize to IterType::Broadcast but found"))); + ThrowsMessage( + HasSubstr("must concretize to IterType::Broadcast but found"))); } // See https://github.com/NVIDIA/Fuser/issues/365 @@ -2840,8 +2842,8 @@ TEST_F(ResizeTest, Slice3DVectorizeManual1) { EXPECT_THAT( [&]() { fe.runFusion(aten_inputs); }, - ::testing::ThrowsMessage(::testing::HasSubstr( - "with word size 2 not possible due to invalid stride"))); + ThrowsMessage( + HasSubstr("with word size 2 not possible due to invalid stride"))); } // Similar to Slice3DVectorizeManual2 but with a middle broadcast @@ -2883,8 +2885,8 @@ TEST_F(ResizeTest, Slice3DVectorizeManual2) { EXPECT_THAT( [&]() { fe.runFusion(aten_inputs); }, - ::testing::ThrowsMessage(::testing::HasSubstr( - "with word size 4 not possible due to invalid stride"))); + ThrowsMessage( + HasSubstr("with word size 4 not possible due to invalid stride"))); } // Repro of issue 540 without transpose @@ -3505,4 +3507,62 @@ TEST_F(ResizeTest, Issue2552) { fec.fusion(), out_tensors, {x_tensor, y_tensor}, __LINE__, __FILE__); } +TEST_F(ResizeTest, Chunk_NegativeSize) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigTensor(1); + fusion->addInput(in); + std::vector outs = chunk(in, /*chunks=*/6, /*dim=*/0); + for (auto* out : outs) { + fusion->addOutput(out); + } + + FusionExecutorCache fec(std::move(fusion)); + EXPECT_THAT( + [&]() { + auto in_tensor = at::randn({13}).cuda(); + fec.runFusionWithInputs({in_tensor}); + }, + ThrowsMessage(HasSubstr("Invalid resized domain extent"))); +} + +TEST_F(ResizeTest, Chunk_SizeZero) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigTensor(1); + fusion->addInput(in); + std::vector outs = chunk(in, /*chunks=*/6, /*dim=*/0); + for (auto* out : outs) { + fusion->addOutput(out); + } + + FusionExecutorCache fec(std::move(fusion)); + auto in_tensor = at::randn({15}).cuda(); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_EQ(out_tensors.back().numel(), 0); +} + +TEST_F(ResizeTest, Chunk_Uneven) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigTensor(1); + fusion->addInput(in); + std::vector outs = chunk(in, /*chunks=*/6, /*dim=*/0); + for (auto* out : outs) { + fusion->addOutput(out); + } + + FusionExecutorCache fec(std::move(fusion)); + auto in_tensor = at::randn({16}).cuda(); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + EXPECT_EQ(out_tensors.back().numel(), 1); +} + } // namespace nvfuser From 2af87e1eb4f697e4a369cc59682374876ea06b8f Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Thu, 22 Aug 2024 18:59:40 -0400 Subject: [PATCH 07/54] Revert "Promote hoisting of vectorize predicates" (#2742) check issue #2741 Regression is recovered after revert, see [dashboard](http://nv/eh4). --- csrc/device_lower/analysis/index_compute.cpp | 10 ++-------- tests/cpp/test_gpu3.cpp | 21 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/csrc/device_lower/analysis/index_compute.cpp b/csrc/device_lower/analysis/index_compute.cpp index 252595aca7a..f69f5a29f33 100644 --- a/csrc/device_lower/analysis/index_compute.cpp +++ b/csrc/device_lower/analysis/index_compute.cpp @@ -383,14 +383,8 @@ IndexingParameters getPredicateInitialIndexParameters( } else { // Similar to the above, loop_id()->extent() is // used here instead of loop->stop(). See the above comment. - if (unswitch_pred) { - loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( - loop_id->extent(), GpuLower::current()->kernel()->oneVal()); - } else { - // For vectorize, zero should be fine as well and that would - // promote better index hoisting - loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); - } + loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( + loop_id->extent(), GpuLower::current()->kernel()->oneVal()); } // When predicating a loop at the maximum end, predicate diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 908526f352a..7c9b16b392e 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -5032,6 +5032,15 @@ TEST_F(NVFuserTest, FusionPropagateVectorizePredicate_CUDA) { auto cond_inputs = InputsOf::output(cond); auto index_it = std::find(cond_inputs.begin(), cond_inputs.end(), loop_index); + auto vec_factor_it = + std::find_if(cond_inputs.begin(), cond_inputs.end(), [](Val* inp) { + auto int_val = inp->value(); + return int_val.hasValue() && + (int_val.as() == vec_factor - 1 || + int_val.as() == -(vec_factor - 1)); + }); + // If vectorized, the predicate should use (vec_factor - 1) or + // -(vec_factor - 1) rather than the loop index. if (vectorized_) { NVF_CHECK( index_it == cond_inputs.end(), @@ -5039,6 +5048,12 @@ TEST_F(NVFuserTest, FusionPropagateVectorizePredicate_CUDA) { loop_index->toInlineString(), " in ", cond->toInlineString()); + NVF_CHECK( + vec_factor_it != cond_inputs.end(), + "Expected to have ", + vec_factor - 1, + " in ", + cond->toInlineString()); } else { NVF_CHECK( index_it != cond_inputs.end(), @@ -5046,6 +5061,12 @@ TEST_F(NVFuserTest, FusionPropagateVectorizePredicate_CUDA) { loop_index->toInlineString(), " in ", cond->toInlineString()); + NVF_CHECK( + vec_factor_it == cond_inputs.end(), + "Not expected to have ", + vec_factor - 1, + " in ", + cond->toInlineString()); } } } From b10f1b2fd9bde995961ddcbd5c3e5f164089b884 Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Fri, 23 Aug 2024 09:23:01 -0400 Subject: [PATCH 08/54] use id model in extent substitution (#2668) While investigating #2702, I found we can also revise `exactMappedExtentSubstitution` using IdModel which detects more replacements than `ExactLogicalDomainMap`. **Before this PR**, `exactMappedExtentSubstitution` uses `ExactLogicalDomainMap`: For a fusion with: ``` Inputs: T0_g[ iS0{i0} ], float T1_g[ iS1{i2} ], float Outputs: T4_g[ iS10{32}, iS11{( ceilDiv(i0, 32) )} ], float %kernel_math { T2_l[ iS4{32}rf, iS5{( ceilDiv(i0, 32) )}rf ] = view( T0_g[ iS0{i0} ] ) T3_l[ iS8{32}rf, iS9{( ceilDiv(i2, 32) )}rf ] = view( T1_g[ iS1{i2} ] ) T4_g[ iS10{32}, iS11{( ceilDiv(i0, 32) )} ] = T2_l[ iS4{32}rf, iS5{( ceilDiv(i0, 32) )}rf ] + T3_l[ iS8{32}rf, iS9{( ceilDiv(i2, 32) )}rf ]; } // %kernel_math ``` ``` ExactLogicalDomainMap: disjoint sets{ { iS0{i0}; iS3{i0}rf } { iS1{i2}; iS7{i2}rf } { iS9{( ceilDiv(i2, 32) )}rf; iS5{( ceilDiv(i0, 32) )}rf; iS11{( ceilDiv(i0, 32) )} } { iS8{32}rf; iS4{32}rf; iS10{32} } } ``` **After this PR**, `exactMappedExtentSubstitution` uses `exact_graph.disjointValSets()` ``` val_sets: disjoint sets{ { iS10{32}; iS4{32}rf; iS8{32}rf } { iS11{( ceilDiv(i0, 32) )}; iS5{( ceilDiv(i0, 32) )}rf; iS9{( ceilDiv(i2, 32) )}rf } { iS3{i0}rf; iS0{i0}; iS7{i2}rf; iS1{i2} } } ``` **The difference is `exact_graph.disjointValSets()` can detect `{i0} <=>{i2} ` while `ExactLogicalDomainMap` can't.** fusion is modified to: ``` Inputs: T0_g[ iS0{i0} ], float T1_g[ iS13{i0} ], float Outputs: T4_g[ iS10{32}, iS11{( ceilDiv(i0, 32) )} ], float %kernel_math { T2_l[ iS4{32}rf, iS5{( ceilDiv(i0, 32) )}rf ] = view( T0_g[ iS0{i0} ] ) T3_l[ iS8{32}rf, iS9{( ceilDiv(i0, 32) )}rf ] = view( T1_g[ iS13{i0} ] ) T4_g[ iS10{32}, iS11{( ceilDiv(i0, 32) )} ] = T2_l[ iS4{32}rf, iS5{( ceilDiv(i0, 32) )}rf ] + T3_l[ iS8{32}rf, iS9{( ceilDiv(i0, 32) )}rf ]; } // %kernel_math ``` --------- Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> --- .../exact_mapped_extent_substitution.cpp | 39 ++++++++++++------- tests/cpp/test_preseg_passes.cpp | 32 +++++++++++++++ 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/csrc/preseg_passes/exact_mapped_extent_substitution.cpp b/csrc/preseg_passes/exact_mapped_extent_substitution.cpp index 666076efc90..82c3a0bad77 100644 --- a/csrc/preseg_passes/exact_mapped_extent_substitution.cpp +++ b/csrc/preseg_passes/exact_mapped_extent_substitution.cpp @@ -6,11 +6,11 @@ */ // clang-format on #include +#include #include #include #include #include - namespace nvfuser::preseg_passes { namespace { @@ -30,15 +30,21 @@ void exactMappedExtentSubstitution(Fusion* fusion) { // map non-const extents to const extents std::unordered_map replacement_map; - const auto mapped_sets = ExactLogicalDomainMap(fusion).getMappedSets(); - // Loop over each exact root domain set - for (const auto& set_ptr : mapped_sets.disjointSets()) { + // Build the exact graph + IdModel id_model(fusion, false, false, false); + id_model.buildExactGraph(); + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + const DisjointSets& val_sets = exact_graph.disjointValSets(); + + // Loop over each set of values + for (const auto& set_ptr : val_sets.disjointSets()) { // (1) pick a const extent // (2) if no const extent, pick the var with the lowest name() Val* const_extent = nullptr; Val* lowest_val = nullptr; - for (auto id : *set_ptr) { - if (isNonSubstitutableID(id)) { + for (auto v : *set_ptr) { + auto id = dynamic_cast(v); + if (id == nullptr || isNonSubstitutableID(id)) { continue; } // find the const extent, if already seen, check if they are the same @@ -60,8 +66,9 @@ void exactMappedExtentSubstitution(Fusion* fusion) { } // replace with const extents. // if no const extents, replace with the one with the lowest name. - for (auto id : *set_ptr) { - if (isNonSubstitutableID(id)) { + for (auto v : *set_ptr) { + auto id = dynamic_cast(v); + if (id == nullptr || isNonSubstitutableID(id)) { continue; } replacement_map.emplace( @@ -77,17 +84,23 @@ void exactMappedExtentSubstitution(Fusion* fusion) { void ExactMappedExtentSubstitutionPass::runPass(Fusion* fusion) { if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { - debug() << "ExactLogicalDomainMap before " << name() << ":" << std::endl; - const auto mapped_sets = ExactLogicalDomainMap(fusion).getMappedSets(); - debug() << mapped_sets.toString() << std::endl; + debug() << "DisjointSets before " << name() << ":" << std::endl; + IdModel id_model(fusion, /*build_graphs=*/false); + id_model.buildExactGraph(); + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + const DisjointSets& val_sets = exact_graph.disjointValSets(); + debug() << val_sets.toString() << std::endl; } exactMappedExtentSubstitution(fusion); if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { debug() << "ExactLogicalDomainMap after " << name() << ":" << std::endl; - const auto mapped_sets = ExactLogicalDomainMap(fusion).getMappedSets(); - debug() << mapped_sets.toString() << std::endl; + IdModel id_model(fusion, false, false, false); + id_model.buildExactGraph(); + const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + const DisjointSets& val_sets = exact_graph.disjointValSets(); + debug() << val_sets.toString() << std::endl; } } diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp index ac941ea9132..feefef8e124 100644 --- a/tests/cpp/test_preseg_passes.cpp +++ b/tests/cpp/test_preseg_passes.cpp @@ -642,4 +642,36 @@ TEST_F(PresegTest, ReplaceOutput) { testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); } +TEST_F(PresegTest, ExtentSubstitution) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const std::vector input_shape = {128}; + const std::vector group_shape = {32, 4}; + auto tv0 = makeContigTensor(input_shape.size()); + auto tv1 = makeContigTensor(input_shape.size()); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = reshape(tv0, input_shape, group_shape); + auto tv3 = reshape(tv1, input_shape, group_shape); + auto tv4 = add(tv2, tv3); + fusion->addOutput(tv4); + + OptimizationPass::runPass(fusion.get()); + // two inputs should be same after ExactMappedExtentSubstitutionPass in + // OptimizationPass + const auto& inputs = fusion.get()->inputs(); + TensorView* input1 = dynamic_cast(inputs.at(0)); + TensorView* input2 = dynamic_cast(inputs.at(1)); + auto extend1 = input1->getLogicalDomain().at(0)->extent(); + auto extend2 = input2->getLogicalDomain().at(0)->extent(); + EXPECT_EQ(extend1, extend2); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn(input_shape, options); + auto t1 = at::randn(input_shape, options); + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1}); + testValidate( + executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); +} } // namespace nvfuser::preseg_passes From 58aa9759ffbe322fa3b9425859a44be27876d8c9 Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Fri, 23 Aug 2024 08:34:07 -0700 Subject: [PATCH 09/54] Sharded SDPA backwards support (#2826) Adds temporary support for sharded backwards scaled dot product attention. Until https://github.com/NVIDIA/Fuser/issues/2563 is completed. Similar to https://github.com/NVIDIA/Fuser/pull/2565 Similar restrictions: 1. All necessary sharded inputs are manually sharded before the SDPABwdOp is created. We cannot rely on sharding propagation or sharding after the Fusion is created, because the dimension checks are called when the op is created. 2. Only the head dimension is sharded and all inputs and outputs have either a sharded head dimension or unshaded. 3. DID axis is the outermost axis. This is because during evaluation if we see 5 dimensions, it is assumed the first is the DID axis and is appropriately squeezed from the inputs and unsqueezed onto the outputs. --- csrc/ir/nodes.cpp | 17 ++- csrc/logical_domain_map.cpp | 37 +++--- csrc/multidevice/utils.cpp | 4 + csrc/ops/composite.cpp | 30 ++++- csrc/preseg_passes/propagate_shardings.cpp | 5 + tests/cpp/test_sdpa_node.cpp | 140 ++++++++++++++++++++- 6 files changed, 205 insertions(+), 28 deletions(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 6ce16a55113..3d3dc47d4e5 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -4862,9 +4862,17 @@ std::vector SdpaBwdOp::evaluate( const std::vector& inputs) const { // Backward tensor inputs: grad_input, query, key, value, output, logsumexp, // max_q/k + // Temporary handling of DID parallelization. See + // https://github.com/NVIDIA/Fuser/issues/2563 + bool first_dim_is_did = this->key()->as()->axis(0)->isDeviceDim(); std::vector bwd_inputs; for (auto idx : c10::irange(6)) { - bwd_inputs.emplace_back(inputs.at(idx).as()); + auto in_tensor = inputs.at(idx).as(); + // Removing the size 1 from sharded axis from tensors. + if (first_dim_is_did) { + in_tensor = in_tensor.squeeze(0); + } + bwd_inputs.push_back(in_tensor); } const auto dropout_p = inputs.at(6).as(); const auto is_causal = inputs.at(7).as(); @@ -4918,6 +4926,13 @@ std::vector SdpaBwdOp::evaluate( return output.slice(-1, 0, last_dim_size); }; + // Add device dimension back to outputs. + if (first_dim_is_did) { + grad_query = grad_query.unsqueeze(0); + grad_key = grad_key.unsqueeze(0); + grad_value = grad_value.unsqueeze(0); + } + return { slice_last_dim(grad_query), slice_last_dim(grad_key), diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 184baf8aa68..6e67ff84b63 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -263,16 +263,16 @@ std::unordered_map PairwiseLogicalDomainMap::map( if (SdpaBwdOp* op = dynamic_cast(consumer_tv_->definition())) { // Producers: - // grad_attn = [N, H, L, Ev] - // query = [N, H, L, E] - // key = [N, H, S, E] - // value = [N, H, S, Ev] - // attn_out = [N, H, L, Ev] - // logsumexp = [N, H, L] + // grad_attn = [DIDx(D)? N, H, L, Ev] + // query = [DIDx(D)? N, H, L, E] + // key = [DIDx(D)? N, H, S, E] + // value = [DIDx(D)? N, H, S, Ev] + // attn_out = [DIDx(D)? N, H, L, Ev] + // logsumexp = [DIDx(D)? N, H, L] // Consumers: - // grad_query = [N, H, L, E] - // grad_key = [N, H, S, E] - // grad_value = [N, H, S, Ev] + // grad_query = [DID(D)? N, H, L, E] + // grad_key = [DID(D)? N, H, S, E] + // grad_value = [DID(D)? N, H, S, Ev] bool producer_has_s = producer_tv_->sameAs(op->key()) || producer_tv_->sameAs(op->value()); @@ -284,19 +284,18 @@ std::unordered_map PairwiseLogicalDomainMap::map( bool consumer_has_e = consumer_tv_->sameAs(op->grad_query()) || consumer_tv_->sameAs(op->grad_key()); + size_t num_device_dim = + !producer_logical.empty() && producer_logical.at(0)->isDeviceDim() ? 1 + : 0; for (auto idx : c10::irange(producer_logical.size())) { - if (idx < 2) { - // Map N, H from all producers to consumers + // Map N, H from all producers to consumers + // producer/consumer[2] = L/S + // producer/consumer[3] = E/Ev + if ((idx < 2 + num_device_dim) || + (idx == 2 + num_device_dim && producer_has_s == consumer_has_s) || + (idx == 3 + num_device_dim && producer_has_e == consumer_has_e)) { updatePairwiseLogicalDomainMap( producer_logical.at(idx), consumer_root.at(idx)); - } else if (idx == 2 && (producer_has_s == consumer_has_s)) { - // producer/consumer[2] = L/S - updatePairwiseLogicalDomainMap( - producer_logical.at(2), consumer_root.at(2)); - } else if (idx == 3 && (producer_has_e == consumer_has_e)) { - // producer/consumer[3] = E/Ev - updatePairwiseLogicalDomainMap( - producer_logical.at(3), consumer_root.at(3)); } } return dom_map; diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 1bd0dfae7e8..b5b8c7f1725 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -140,6 +140,10 @@ int64_t numDeviceDims(const TensorView* tv) { bool haveDifferentShardings( const TensorView* producer, const TensorView* consumer) { + // cpu scalars are not required to have a mesh + if (producer->isCpuScalar() || consumer->isCpuScalar()) { + return false; + } // exit early in the unsharded case for performance if (!producer->hasDeviceMesh() && !consumer->hasDeviceMesh()) { return false; diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index 1a3d33b2e0c..ac03ddd8294 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -565,15 +565,35 @@ SdpfaBwdResult sdpfa_bwd( auto key_domain = TensorDomain::noReductions(key->getLogicalDomain()); auto value_domain = TensorDomain::noReductions(value->getLogicalDomain()); + // Temporary handling of DID parallelization see + // https://github.com/NVIDIA/Fuser/issues/2563 + bool has_device_dim = (query_domain.size() == 5); + if (has_device_dim) { + auto check_first_is_did = [](const std::vector& ids) -> void { + NVF_CHECK( + ids[0]->isDeviceDim(), + "Only support DID parallelization on outermost axis"); + }; + check_first_is_did(query_domain); + check_first_is_did(key_domain); + check_first_is_did(value_domain); + check_first_is_did(grad_output->getLogicalDomain()); + check_first_is_did(output->getLogicalDomain()); + } + + auto concrete_query_size = TensorDomain::noDevices(query_domain).size(); + auto concrete_key_size = TensorDomain::noDevices(key_domain).size(); + auto concrete_value_size = TensorDomain::noDevices(value_domain).size(); + NVF_CHECK( - query_domain.size() == 4 && key_domain.size() == 4 && - value_domain.size() == 4, + concrete_query_size == 4 && concrete_key_size == 4 && + concrete_value_size == 4, "Expected query, key, and value to be 4D but got: ", - query_domain.size(), + concrete_query_size, " ", - key_domain.size(), + concrete_key_size, " ,and ", - value_domain.size()); + concrete_value_size); NVF_CHECK( !dropout_p || dropout_p->isScalar(), diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 60567c37fb0..e3e4f39f8d9 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -48,6 +49,10 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { } }; + if (tv->isCpuScalar()) { + continue; + } + if (tv->hasDeviceMesh()) { update_if_null(tv_with_mesh, tv); } else { diff --git a/tests/cpp/test_sdpa_node.cpp b/tests/cpp/test_sdpa_node.cpp index c388aa27b3c..772945f0909 100644 --- a/tests/cpp/test_sdpa_node.cpp +++ b/tests/cpp/test_sdpa_node.cpp @@ -166,6 +166,8 @@ void checkSdpaBwdMapping(Fusion* fusion, Expr* op) { std::vector producer_ids = producer_tv->getLogicalDomain(); std::vector consumer_ids = consumer_tv->getMaybeRootDomain(); + size_t num_device_dim = producer_ids.at(0)->isDeviceDim() ? 1 : 0; + // Idx=0: producer_ids[0], consumer_ids[0] = N // Idx=1: producer_ids[1], consumer_ids[1] = H // Idx=2: producer_ids[2], consumer_ids [2] = L/S @@ -182,19 +184,21 @@ void checkSdpaBwdMapping(Fusion* fusion, Expr* op) { consumer_tv->sameAs(sdpa_op->grad_key()); for (auto idx : c10::irange(consumer_ids.size())) { - if (idx < 2) { + if (idx < 2 + num_device_dim) { checkMapped(vg, producer_ids.at(idx), consumer_ids.at(idx)); EXPECT_TRUE(compute_at_map.areMapped( producer_ids.at(idx), consumer_ids.at(idx), IdMappingMode::EXACT)); - } else if (idx == 2 && (producer_has_s == consumer_has_s)) { + } else if ( + idx == (2 + num_device_dim) && (producer_has_s == consumer_has_s)) { checkMapped(vg, producer_ids.at(idx), consumer_ids.at(idx)); EXPECT_TRUE(compute_at_map.areMapped( producer_ids.at(idx), consumer_ids.at(idx), IdMappingMode::EXACT)); - } else if (idx == 3 && (producer_has_e == consumer_has_e)) { + } else if ( + idx == (3 + num_device_dim) && (producer_has_e == consumer_has_e)) { checkMapped(vg, producer_ids.at(idx), consumer_ids.at(idx)); EXPECT_TRUE(compute_at_map.areMapped( producer_ids.at(idx), @@ -826,6 +830,136 @@ TEST_F(SDPATest, Sharded_SdpaFwd) { validateSdpaFwdOutputs(nvf_out, aten_out); } +// TODO: Remove/update when https://github.com/NVIDIA/Fuser/issues/2563 is +// resolved. +TEST_F(SDPATest, Sharded_SdpaBwd) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + constexpr int64_t d = 4; + auto mesh = DeviceMesh::createForNumDevices(d); + std::vector q_shape({d, n, h / d, l, e}); + std::vector kv_shape({d, n, h / d, s, e}); + std::vector attn_shape({d, n, h / d, l, e}); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor q = at::randn({n, h / d, l, e}, options); + at::Tensor k = at::randn({n, h / d, s, e}, options); + at::Tensor v = at::randn({n, h / d, s, e}, options); + + constexpr double dropout_p = 0.2; + constexpr bool is_causal = false; + double scale = 1.0 / std::sqrt(e); + + auto + [output, + log_sumexp, + cum_seq_q, + cum_seq_k, + query_seq_len, + key_seq_len, + philox_seed, + philox_offset, + debug_attn_mask] = + at::_scaled_dot_product_flash_attention( + q, + k, + v, + dropout_p, + is_causal, + /*return_debug_mask=*/false, + scale); + + auto tv_grad_output = makeConcreteTensor(attn_shape, DataType::Half); + auto tvq = makeConcreteTensor(q_shape, DataType::Half); + auto tvk = makeConcreteTensor(kv_shape, DataType::Half); + auto tvv = makeConcreteTensor(kv_shape, DataType::Half); + auto tv_output = makeConcreteTensor(attn_shape, DataType::Half); + auto tv_logsumexp = makeConcreteTensor({d, n, h / d, l}, DataType::Float); + auto tv_seed = makeConcreteTensor({}, DataType::Int); + auto tv_offset = makeConcreteTensor({}, DataType::Int); + + fusion->addInput(tv_grad_output); + fusion->addInput(tvq); + fusion->addInput(tvk); + fusion->addInput(tvv); + fusion->addInput(tv_output); + fusion->addInput(tv_logsumexp); + fusion->addInput(tv_seed); + fusion->addInput(tv_offset); + + for (TensorView* tv : + {tvq, tvk, tvv, tv_grad_output, tv_output, tv_logsumexp}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + + auto tvgrad = sdpfa_bwd( + tv_grad_output, + tvq, + tvk, + tvv, + tv_output, + tv_logsumexp, + /*dropout_p=*/IrBuilder::create(dropout_p), + /*is_causal=*/IrBuilder::create(is_causal), + tv_seed, + tv_offset, + /*scale=*/nullptr); + + for (TensorView* tv : + {tvgrad.grad_query, tvgrad.grad_key, tvgrad.grad_value}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + fusion->addOutput(tv); + } + + checkSdpaBwdMapping(fusion.get(), tvgrad.grad_query->definition()); + + at::Tensor grad_out = at::randn({n, h / d, l, e}, options); + + std::vector sdpa_bwd_inputs = { + grad_out.unsqueeze(0), + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + output.unsqueeze(0), + log_sumexp.unsqueeze(0), + philox_seed, + philox_offset}; + + FusionExecutorCache fec(std::move(fusion)); + auto out = fec.runFusionWithInputs(sdpa_bwd_inputs); + + auto [ref_grad_query, ref_grad_key, ref_grad_value] = + at::_scaled_dot_product_flash_attention_backward( + grad_out, + q, + k, + v, + output, + log_sumexp, + cum_seq_q, + cum_seq_k, + /*max_q=*/*query_seq_len.maybe_as_int(), + /*max_k=*/*key_seq_len.maybe_as_int(), + dropout_p, + is_causal, + philox_seed, + philox_offset, + /*scale=*/scale); + + testValidate( + fec.fusion(), + out, + sdpa_bwd_inputs, + {ref_grad_query.unsqueeze(0), + ref_grad_key.unsqueeze(0), + ref_grad_value.unsqueeze(0)}, + __LINE__, + __FILE__); +} + TEST_F(SDPATest, ComputeAt) { NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); auto fusion = std::make_unique(); From 6dba9a837deb14b82bb87db8c8e2a07fb02cad60 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 23 Aug 2024 15:32:21 -0700 Subject: [PATCH 10/54] Support contiguous indexing (#2752) This PR enables contiguous indexing in the new indexer. It's only about indexing of tensor expressions. Predicates will be addressed as a follow-up PR. The basic idea is when there's a merge of two contiguous allocation domains, we can use the merge output domain to index the 2D space represented by the two allocation domains. This analysis exists for the existing indexer, mostly in csrc/contiguity.cpp. This PR attempts to reuse it as much as possible with minimal changes. The main difference between the existing and the new indexing methods is how the index propagation is done. In the existing method, it's always done through the iter domains of a tensor, whereas in the new method it's done over the AlmostExact graph. Accordingly, the contiguity analysis is only done from the root to loop domains of a tensor, but with the new indexer, it also needs to be done through the graph traversal path. That's the main adaptation done in this PR. In order to reuse the existing analysis, I created a new class `OrderedIdGroupInformation`, which is a subclass of `OrderedIdInformation`. For `ContigIDs`, I just created a new class `ContigIDGroups`, which is not a subclass of `ContigIDs`, but still pretty similar to `ContigIDs`. Once everything is moved to using the new indexer, we could clean up the existing analysis, but for now keeping these two classes seems to be necessary. Follow-up TODOs - Contig indexing for predicate - Contig analysis for both forward and backward traversal --- CMakeLists.txt | 1 + csrc/contiguity.cpp | 74 ++-- csrc/contiguity.h | 40 +- csrc/disjoint_set.h | 4 +- csrc/id_model/contiguity.cpp | 145 +++++++ csrc/id_model/contiguity.h | 158 +++++++ csrc/id_model/id_model_index_compute.cpp | 2 +- csrc/id_model/indexing.cpp | 114 ++++- csrc/id_model/indexing.h | 24 +- tests/cpp/test_indexing.cpp | 530 ++++++++++++++++++++--- 10 files changed, 986 insertions(+), 106 deletions(-) create mode 100644 csrc/id_model/contiguity.cpp create mode 100644 csrc/id_model/contiguity.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 244ac816c87..b043b7f1ca6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -137,6 +137,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/host_ir/executor.cpp ${NVFUSER_SRCS_DIR}/host_ir/host_ir.cpp ${NVFUSER_SRCS_DIR}/id_model/circular_buffer_indexing.cpp + ${NVFUSER_SRCS_DIR}/id_model/contiguity.cpp ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp ${NVFUSER_SRCS_DIR}/id_model/id_model_index_compute.cpp ${NVFUSER_SRCS_DIR}/id_model/indexing.cpp diff --git a/csrc/contiguity.cpp b/csrc/contiguity.cpp index 459f46a5f97..95333e139bd 100644 --- a/csrc/contiguity.cpp +++ b/csrc/contiguity.cpp @@ -14,11 +14,10 @@ namespace nvfuser { OrderedIdInformation::OrderedIdInformation( - const std::vector& ids, const std::vector& alloc_domain, std::shared_ptr concrete_info) : active_ids_(alloc_domain), concrete_info_(std::move(concrete_info)) { - if (ids.empty() || alloc_domain.empty()) { + if (alloc_domain.empty()) { return; } @@ -34,12 +33,14 @@ OrderedIdInformation::OrderedIdInformation( exclusively_consumes_allocs_.emplace(alloc_id); } +} +void OrderedIdInformation::traverseTo(const std::vector& ids) { // Iterate from the allocation domain to the provided ids and fill // consistently_ordered_ids_, id_to_alloc_ids_, and // exclusively_consumes_allocs_ for all the IDs auto exprs = StmtSort::getExprsBetween( - {alloc_domain.begin(), alloc_domain.end()}, {ids.begin(), ids.end()}); + {active_ids_.begin(), active_ids_.end()}, {ids.begin(), ids.end()}); for (auto expr : exprs) { OptInDispatch::dispatch(expr); @@ -53,7 +54,7 @@ bool OrderedIdInformation::checkExclusivelyConsumesAllocs(IterDomain* id) { id->toString(), " to be in the active ID set."); - auto alloc_id_it = id_to_alloc_ids_.find(id); + auto alloc_id_it = findAllocIDs(id); NVF_ERROR( alloc_id_it != id_to_alloc_ids_.end(), "Error replaying transforms in contiguous ID checker, couldn't find mapped allocs of ", @@ -68,7 +69,7 @@ bool OrderedIdInformation::checkExclusivelyConsumesAllocs(IterDomain* id) { continue; } - auto alloc_id_it = id_to_alloc_ids_.find(other_active_id); + auto alloc_id_it = findAllocIDs(other_active_id); NVF_ERROR( alloc_id_it != id_to_alloc_ids_.end(), "Error replaying transforms in contiguous ID checker, couldn't find mapped allocs of ", @@ -102,32 +103,25 @@ void OrderedIdInformation::handle(Merge* merge) { const auto& inner_alloc_ids = findAllocIDs(merge->inner())->second; const auto& outer_alloc_ids = findAllocIDs(merge->outer())->second; - // TODO: Concretization may prevent contiguous indexing or vectorization. - // It prevents contiguous indexing if the concretization is within the IDs - // that are used for indexing. - // For vectorization it just means we need to make sure the extents of the - // axes to the right of the broadcast allocation domain in the contigous - // merge is bigger than the vectorization dimension. And that the tensor - // buffer supports the vector word size (always done). - bool outer_is_concretized_bcast = merge->outer()->isBroadcast() && - concrete_info_->isConcretized(merge->outer()); - - bool inner_is_concretized_bcast = merge->inner()->isBroadcast() && - concrete_info_->isConcretized(merge->inner()); - // Update maps - // Find the position inner would have to have to be considered ordered + // Find the position inner would have to be considered ordered auto pos_after_outer = outer_pos + 1; for (; pos_after_outer < int64_t(active_ids_.size()); pos_after_outer++) { if (active_ids_[pos_after_outer] == nullptr) { // Can't be considered ordered after a nullptr break; } - if (active_ids_[pos_after_outer]->isReduction() || - ((active_ids_[pos_after_outer]->isBroadcast() && - !concrete_info_->isConcretized(active_ids_[pos_after_outer])))) { - // Skip reduction or broadcast axes that aren't concretized in the fusion - continue; + // When using IdModel, reduction domains are excluded from + // allocation domains but loop promotion may pick reduction + // domains, which should just be treated as normal domains. + if (!using_id_graph_) { + if (active_ids_[pos_after_outer]->isReduction() || + ((active_ids_[pos_after_outer]->isBroadcast() && + !isConcretized(active_ids_[pos_after_outer])))) { + // Skip reduction or broadcast axes that aren't concretized in the + // fusion + continue; + } } break; } @@ -142,8 +136,30 @@ void OrderedIdInformation::handle(Merge* merge) { // Inner could be a broadcast, so doesn't have to be right on // pos_after_outer as that ID (if it exists) should not be a broadcast. // However, merging over a broadcast should be fine. - inner_pos <= pos_after_outer && !inner_is_concretized_bcast && - !outer_is_concretized_bcast; + inner_pos <= pos_after_outer; + + if (!using_id_graph_) { + // TODO: Concretization may prevent contiguous indexing or vectorization. + // It prevents contiguous indexing if the concretization is within the IDs + // that are used for indexing. + // For vectorization it just means we need to make sure the extents of the + // axes to the right of the broadcast allocation domain in the contigous + // merge is bigger than the vectorization dimension. And that the tensor + // buffer supports the vector word size (always done). + // + // This shouldn't matter when using the IdModel-based + // indexer. When concretized, that should be reflected with the + // indexing path, and as long as the indexing path has a + // contiguous merge, its output should be safe to index. See + // also ContigIndexingTest.ConcretizedBroadcastMerge for a + // concrete example. + bool outer_is_concretized_bcast = + merge->outer()->isBroadcast() && isConcretized(merge->outer()); + bool inner_is_concretized_bcast = + merge->inner()->isBroadcast() && isConcretized(merge->inner()); + out_ordered = out_ordered && !inner_is_concretized_bcast && + !outer_is_concretized_bcast; + } if (out_ordered) { consistently_ordered_ids_.emplace(merge->out()); @@ -443,7 +459,7 @@ ContigIDs::ContigIDs( std::make_shared(ids[0]->fusion()); consistent_transform_info_ = std::make_unique( - ids, alloc_domain, concrete_info_); + OrderedIdInformation::get(ids, alloc_domain, concrete_info_)); } build(ids); } @@ -471,9 +487,7 @@ ContigIDs::ContigIDs( ignore_indexability_(ignore_indexability), ignore_consistent_ordering_(ignore_consistent_ordering), consistent_transform_info_(std::make_unique( - ids, - alloc_domain, - concrete_info_)), + OrderedIdInformation::get(ids, alloc_domain, concrete_info_))), non_divisible_id_info_(ids, alloc_domain, divisible_splits_) { build(ids); } diff --git a/csrc/contiguity.h b/csrc/contiguity.h index eaf39cc1e46..f64059c65a4 100644 --- a/csrc/contiguity.h +++ b/csrc/contiguity.h @@ -29,19 +29,21 @@ namespace nvfuser { // complex transformations. class OrderedIdInformation : public OptInDispatch { public: - OrderedIdInformation() = delete; - - OrderedIdInformation( + static OrderedIdInformation get( const std::vector& ids, const std::vector& alloc_domain, - std::shared_ptr concrete_info); + std::shared_ptr concrete_info) { + OrderedIdInformation info(alloc_domain, concrete_info); + info.traverseTo(ids); + return info; + } const std::unordered_map>& idToAllocIds() const { return id_to_alloc_ids_; } - bool isConsistentlyOrdered(IterDomain* id) const { + virtual bool isConsistentlyOrdered(IterDomain* id) const { return consistently_ordered_ids_.find(id) != consistently_ordered_ids_.end(); } @@ -51,7 +53,20 @@ class OrderedIdInformation : public OptInDispatch { exclusively_consumes_allocs_.end(); } - private: + virtual std::unordered_map>:: + const_iterator + findAllocIDs(IterDomain* id) const { + return id_to_alloc_ids_.find(id); + } + + protected: + OrderedIdInformation( + const std::vector& alloc_domain, + std::shared_ptr concrete_info = + nullptr); + + void traverseTo(const std::vector& ids); + // Returns if the id in active_ids should be in exclusively_consumes_allocs_ bool checkExclusivelyConsumesAllocs(IterDomain* id); @@ -65,7 +80,8 @@ class OrderedIdInformation : public OptInDispatch { void handle(Resize* resize) override; - auto findActiveId(IterDomain* id) const { + virtual std::vector::const_iterator findActiveId( + IterDomain* id) const { return std::find(active_ids_.begin(), active_ids_.end(), id); } @@ -79,11 +95,12 @@ class OrderedIdInformation : public OptInDispatch { return std::distance(active_ids_.begin(), it); } - auto findAllocIDs(IterDomain* id) const { - return id_to_alloc_ids_.find(id); + bool isConcretized(IterDomain* id) const { + NVF_ERROR(concrete_info_ != nullptr); + return concrete_info_->isConcretized(id); } - private: + protected: // Track which allocation ids were used to generate each iter domain std::unordered_map> id_to_alloc_ids_; @@ -126,6 +143,9 @@ class OrderedIdInformation : public OptInDispatch { // the domain is concretized within the local indexing, not in the entire // fusion. std::shared_ptr concrete_info_; + + // TODO: Temporary WAR to do ContigIDGroup-specific processing + bool using_id_graph_ = false; }; // Based on provided divisible split set, goes through expressions and marks all diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 92a6c4fb8e2..f989506d351 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -200,7 +200,7 @@ class VectorOfUniqueEntries { // Erase given entry from the containers if // there is a match. - void erase(T entry) { + int64_t erase(T entry) { vector_.erase( std::remove_if( vector_.begin(), @@ -208,7 +208,7 @@ class VectorOfUniqueEntries { [entry](T val) { return val == entry; }), vector_.end()); - set_.erase(entry); + return static_cast(set_.erase(entry)); } // Insert elements at the end of the container. diff --git a/csrc/id_model/contiguity.cpp b/csrc/id_model/contiguity.cpp new file mode 100644 index 00000000000..ecc9c74fc21 --- /dev/null +++ b/csrc/id_model/contiguity.cpp @@ -0,0 +1,145 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include + +namespace nvfuser { + +ContigIDGroups::ContigIDGroups( + const std::vector& alloc_domains, + std::vector contiguity, + const ExprPath& path_from_alloc, + const ValGraph& graph) + : graph_(graph), + alloc_domains_(alloc_domains), + alloc_contiguity_(std::move(contiguity)), + consistent_transform_info_( + std::make_unique( + OrderedIdGroupInformation::get( + alloc_domains, + path_from_alloc, + graph))) { + if (alloc_domains_.empty()) { + return; + } + + NVF_ERROR( + alloc_domains_.size() == alloc_contiguity_.size(), + "Arguments don't match ", + alloc_domains_.size(), + " != ", + alloc_contiguity_.size()); + + for (const auto index_domain_i : c10::irange(alloc_domains_.size())) { + IterDomain* index_domain = alloc_domains_.at(index_domain_i); + NVF_ERROR( + !index_domain->isBroadcast(), + "Broadcast domain should not be an index domain: ", + index_domain->toString()); + + alloc_to_contig_ids_[index_domain] = graph_.toGroup(index_domain); + + auto alloc_contiguity = alloc_contiguity_.at(index_domain_i); + + if (alloc_contiguity && + index_domain->getIterType() != IterType::GatherScatter) { + contig_ids_.emplace(graph_.toGroup(index_domain)); + } + } + + for (const auto& [eg, direction] : path_from_alloc) { + // Propagate resize and non-divisible dependencies + const auto inputs = direction == Direction::Forward + ? graph_.inputGroups(eg) + : graph_.outputGroups(eg); + const auto outputs = direction == Direction::Forward + ? graph_.outputGroups(eg) + : graph_.inputGroups(eg); + if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& inp) { + return resize_deps_.count(inp) > 0; + })) { + for (const auto& out : outputs) { + resize_deps_.insert(out); + } + } + + dispatch(eg, direction); + } +} + +void ContigIDGroups::handle(Merge* merge, Direction direction) { + // Only forward direction is supported for now + if (direction != Direction::Forward) { + return; + } + + // If output is not consistently ordered or doesn't solely consume all + // allocation domains in its dependencies, then it can't be a contiguously + // indexable iterdomain. + if (!consistent_transform_info_->isConsistentlyOrdered(merge->out())) { + return; + } + + if (!consistent_transform_info_->exclusivelyConsumesAllocs(merge->out())) { + return; + } + + // Check allocation domains for contiguity + auto alloc_ids_it = consistent_transform_info_->findAllocIDs(merge->out()); + VectorOfUniqueEntries alloc_ids = alloc_ids_it->second; + for (auto alloc_id_i : c10::irange(alloc_domains_.size())) { + auto alloc_id = alloc_domains_[alloc_id_i]; + if (alloc_ids.erase(alloc_id) == 0) { + continue; + } + auto alloc_contiguity = alloc_contiguity_.at(alloc_id_i); + // If we're indexing: + // we could still potentially consider this ID linearly indexable, as we + // could multiple the index by the last allocation's stride. See + // ContigIndexingTest.NonContigInnermost for a concrete example. + if (!alloc_contiguity && !alloc_ids.empty()) { + return; + } + } + + // Don't allow contig indexing after resize as we need traverse back + // at least to direct outputs of resize ops + if (resize_deps_.count(graph_.toGroup(merge->out()))) { + return; + } + + // Now we know merge->out is a contiguously indexable ID + + for (auto alloc_id : alloc_ids_it->second) { + alloc_to_contig_ids_[alloc_id] = graph_.toGroup(merge->out()); + } + + contig_ids_.emplace(graph_.toGroup(merge->out())); +} + +void ContigIDGroups::handle(Resize* resize, Direction direction) { + if (direction == Direction::Forward) { + resize_deps_.emplace(graph_.toGroup(resize->out())); + } else { + resize_deps_.emplace(graph_.toGroup(resize->in())); + } +} + +std::unordered_map getContigDomains( + const std::vector& alloc_domains, + const std::vector& alloc_contiguity, + const ExprPath& path_from_alloc, + const ValGraph& graph) { + ContigIDGroups contig_finder( + alloc_domains, alloc_contiguity, path_from_alloc, graph); + + return contig_finder.allocToContigIDs(); +} + +} // namespace nvfuser diff --git a/csrc/id_model/contiguity.h b/csrc/id_model/contiguity.h new file mode 100644 index 00000000000..5be5c8db3b2 --- /dev/null +++ b/csrc/id_model/contiguity.h @@ -0,0 +1,158 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser { + +// Minimal adaptation of OrderedIdInformation for IdModel. Note that +// the analysis only propagates forward for now. +class OrderedIdGroupInformation : public OrderedIdInformation { + public: + // Run the ordering analysis from given allocation domains through + // a given traversal path + static OrderedIdGroupInformation get( + const std::vector& alloc_domain, + const ExprPath& path_from_alloc, + const ValGraph& graph) { + OrderedIdGroupInformation info(alloc_domain, graph); + info.traverse(path_from_alloc); + return info; + } + + // Traversal is based on the AlmostExact graph, so matching of iter + // domains also needs to be done with the same graph + bool isConsistentlyOrdered(IterDomain* id) const override { + return std::find_if( + consistently_ordered_ids_.begin(), + consistently_ordered_ids_.end(), + [&](IterDomain* consistent_id) -> bool { + return graph_.disjointValSets().strictAreMapped( + consistent_id, id); + }) != consistently_ordered_ids_.end(); + } + + std::unordered_map>:: + const_iterator + findAllocIDs(IterDomain* id) const override { + // This is a little ugly workaround. id_to_alloc_ids_ is a map of + // iter domains. If it were a map from ValGroup, this lookup + // should have been O(1) + return std::find_if( + id_to_alloc_ids_.begin(), + id_to_alloc_ids_.end(), + [&](const auto& kv) -> bool { + return graph_.disjointValSets().strictAreMapped(kv.first, id); + }); + } + + protected: + OrderedIdGroupInformation( + const std::vector& alloc_domain, + const ValGraph& graph) + : OrderedIdInformation(alloc_domain), graph_(graph) { + using_id_graph_ = true; + } + + // Currently only forward propagation is supported + void traverse(const ExprPath& path_from_alloc) { + for (const auto& [eg, direction] : path_from_alloc) { + if (direction == Direction::Backward) { + // TODO: support Backward prop + continue; + } + dispatch(eg->front()); + } + } + + std::vector::const_iterator findActiveId( + IterDomain* id) const override { + NVF_ERROR(id != nullptr); + auto it = std::find_if( + active_ids_.begin(), + active_ids_.end(), + [&](IterDomain* active_id) -> bool { + return active_id != nullptr && + graph_.disjointValSets().strictAreMapped(active_id, id); + }); + return it; + } + + private: + const ValGraph& graph_; +}; + +// Adapted from ContigIDs +class ContigIDGroups { + public: + ContigIDGroups( + const std::vector& alloc_domains, + std::vector contiguity, + const ExprPath& path_from_alloc, + const ValGraph& graph); + + void dispatch(const ExprGroup& eg, Direction direction) { + NVF_ERROR(!eg->empty()); + Expr* expr = eg->front(); + + // Currently not propagating any contiguity information with + // swizzles as contiguity is generally not preserved after swizzles. + // But in follow ups we could gradually add back a few special + // cases, depending on specific swizzle type and axes. + + if (auto merge = dynamic_cast(expr)) { + handle(merge, direction); + } else if (auto resize = dynamic_cast(expr)) { + handle(resize, direction); + } + } + + void handle(Merge* merge, Direction direction); + + void handle(Resize* resize, Direction direction); + + const std::unordered_set& contigIDs() const { + return contig_ids_; + } + + const std::unordered_map& allocToContigIDs() const { + return alloc_to_contig_ids_; + } + + private: + // Indexing traversal graph. + const ValGraph& graph_; + // Domains to analyze contiguity. They are typically allocation + // domains but if this is a predicate indexing pass, they are + // likely logical domains. + const std::vector alloc_domains_; + // Contiguity of alloc_domains_ + const std::vector alloc_contiguity_; + std::unique_ptr consistent_transform_info_; + + // Contig domain groups + std::unordered_set contig_ids_; + // Mapping of allocation domains to contig groups + std::unordered_map alloc_to_contig_ids_; + // All domains that have dependencies with resize ops + std::unordered_set resize_deps_; +}; + +// Get a contiguous indexing domain for a given allocation domain. If +// no such domain is found, just the allocation domain itself is +// returned. +std::unordered_map getContigDomains( + const std::vector& alloc_domains, + const std::vector& alloc_contiguity, + const ExprPath& path_from_alloc, + const ValGraph& graph); + +} // namespace nvfuser diff --git a/csrc/id_model/id_model_index_compute.cpp b/csrc/id_model/id_model_index_compute.cpp index 3d2d51a1e18..939b5ca0e13 100644 --- a/csrc/id_model/id_model_index_compute.cpp +++ b/csrc/id_model/id_model_index_compute.cpp @@ -40,7 +40,7 @@ void IdGraphIndexCompute::handle(Merge* merge) { auto outer_idx = getIndex(merge->outer()); auto inner_idx = getIndex(merge->inner()); auto out_idx = SimplifyingIrBuilder::addExpr( - SimplifyingIrBuilder::mulExpr(inner_ext, outer_idx), inner_idx); + SimplifyingIrBuilder::mulExpr(outer_idx, inner_ext), inner_idx); setIndex(merge->out(), out_idx); } else { auto out_idx = getIndex(merge->out()); diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index e5d2f2a5cf1..db8c5ec2053 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -385,6 +386,7 @@ class AllocationDomainSetup : private kir::IrVisitor { // allocation domains, which aren't relevant for indexing std::vector actual_allocation_domains; std::vector actual_strides; + std::vector actual_contiguity; for (const auto i : c10::irange(allocation_domains.size())) { auto allocation_domain = allocation_domains.at(i); auto promotion_domain = promoted_allocation_domains.at(i); @@ -395,9 +397,13 @@ class AllocationDomainSetup : private kir::IrVisitor { NVF_ERROR(stride != nullptr); actual_allocation_domains.push_back(promotion_domain); actual_strides.push_back(stride); + auto contig = contiguity.at(i); + NVF_ERROR(contig.has_value()); + actual_contiguity.push_back(contig.value()); } - return IndexingAllocationInfo{actual_allocation_domains, actual_strides}; + return IndexingAllocationInfo{ + actual_allocation_domains, actual_strides, actual_contiguity}; } // Reorder non-logical allocation domains to follow the ordering of @@ -919,20 +925,15 @@ Val* TensorIndexer::getLinearIndex( const auto alloc_info = getIndexingAllocationInfo(tv); - auto indices = getIndexFor( - expr, - as_consumer, - traversalGraph().toGroups(alloc_info.domains), - for_loops); - NVF_ERROR(indices.size() == alloc_info.domains.size()); + const auto [contig_indices, contig_strides] = + getContigIndexFor(expr, as_consumer, alloc_info, for_loops); // Linearize the indices with strides. - // TODO: Contiguous indexing Val* index = tv->fusion()->zeroVal(); - for (const auto i : c10::irange(alloc_info.domains.size())) { - Val* stride = alloc_info.strides.at(i); + for (const auto i : c10::irange(contig_indices.size())) { + Val* stride = contig_strides.at(i); index = SimplifyingIrBuilder::addExpr( - index, SimplifyingIrBuilder::mulExpr(indices.at(i), stride)); + index, SimplifyingIrBuilder::mulExpr(contig_indices.at(i), stride)); } // If a tensor is circular buffered, it also requires indexing of @@ -1232,4 +1233,95 @@ std::vector TensorIndexer::getPredicates( return info_vec; } +std::pair, std::vector> TensorIndexer:: + getContigDomainsAndStrides( + const IndexingAllocationInfo& alloc_info, + const ExprPath& traversal_path) const { + const std::unordered_map& contig_domains = + getContigDomains( + alloc_info.domains, + alloc_info.contiguity, + reverse(traversal_path), + traversalGraph()); + + // Find contiguous domains to index + std::unordered_set already_indexed_domains; + std::deque contig_alloc_groups; + std::deque contig_strides; + for (const auto i : c10::irange(alloc_info.domains.size())) { + // Traverse back from the innermost domains so that the right + // stride val is picked up for each contiguous domain + auto i1 = alloc_info.domains.size() - 1 - i; + IterDomain* allocation_domain = alloc_info.domains.at(i1); + auto contig_domains_it = contig_domains.find(allocation_domain); + NVF_ERROR( + contig_domains_it != contig_domains.end(), + "No contig domain mapping found for ", + allocation_domain->toString()); + + const ValGroup& contig_domain_group = contig_domains_it->second; + if (already_indexed_domains.find(contig_domain_group) != + already_indexed_domains.end()) { + continue; + } + already_indexed_domains.emplace(contig_domain_group); + + contig_alloc_groups.push_front(contig_domain_group); + contig_strides.push_front(alloc_info.strides.at(i1)); + } + + return { + {contig_alloc_groups.begin(), contig_alloc_groups.end()}, + {contig_strides.begin(), contig_strides.end()}}; +} + +std::pair, std::vector> TensorIndexer:: + getContigIndexFor( + const Expr* expr, + bool as_consumer, + const IndexingAllocationInfo& alloc_info, + const std::vector& for_loops) const { + const auto& index_groups = traversalGraph().toGroups(alloc_info.domains); + auto index_info = computeIndex(expr, index_groups, for_loops); + const auto& index_map = index_info.index_map; + const auto& replacement_map = getIndexReplacementMap( + expr, as_consumer, index_info.loop_domains, for_loops, index_map); + + std::vector contig_alloc_groups; + std::vector contig_strides; + + if (isContigIndexingEnabled()) { + const auto& contig_alloc_strides = + getContigDomainsAndStrides(alloc_info, index_info.traversal_path); + contig_alloc_groups = contig_alloc_strides.first; + contig_strides = contig_alloc_strides.second; + } else { + std::transform( + alloc_info.domains.begin(), + alloc_info.domains.end(), + std::back_inserter(contig_alloc_groups), + [&](IterDomain* allocation_domain) { + return traversalGraph().toGroup(allocation_domain); + }); + contig_strides = {alloc_info.strides.begin(), alloc_info.strides.end()}; + } + + std::vector result; + result.reserve(contig_alloc_groups.size()); + + for (const auto i : c10::irange(contig_alloc_groups.size())) { + const auto& contig_domain_group = contig_alloc_groups.at(i); + auto idx_it = index_map.find(contig_domain_group); + NVF_ERROR( + idx_it != index_map.end(), + "Index not found for ", + contig_domain_group->front()->toString()); + Val* idx = idx_it->second; + Val* replaced_idx = ir_utils::replaceValRecursively(idx, replacement_map); + result.push_back(replaced_idx); + } + + return {result, contig_strides}; +} + } // namespace nvfuser diff --git a/csrc/id_model/indexing.h b/csrc/id_model/indexing.h index 12e5d7f15ff..473eefabeea 100644 --- a/csrc/id_model/indexing.h +++ b/csrc/id_model/indexing.h @@ -12,10 +12,11 @@ #include #include #include +#include #include #include -// Just for RootPredicateInfo. Should be moved to its own header file +// Just for PredicateInfo. Should be moved to its own header file #include #include @@ -35,6 +36,7 @@ struct IndexingInfo { struct IndexingAllocationInfo { std::vector domains; std::vector strides; + std::vector contiguity; }; // The basic algorithm of indexing is: @@ -55,6 +57,10 @@ class TensorIndexer { // non-const reference TensorIndexer(IdModel& id_model); + bool isContigIndexingEnabled() const { + return !isOptionDisabled(DisableOption::ContigIndexing); + } + // Get a linear index of a given tensor appearing in a given expr, either // as a consumer or a producer. The predicate indexing will have a // separate interface. @@ -75,6 +81,13 @@ class TensorIndexer { const ValGroups& index_groups, const std::vector& loops) const; + // Get the contig indices of the given ID groups with their strides + std::pair, std::vector> getContigIndexFor( + const Expr* expr, + bool as_consumer, + const IndexingAllocationInfo& alloc_info, + const std::vector& loops) const; + // The AlmostExact graph is used since size-1 splits and merges // should not affect actual index exprs. // Returns non-const reference because indexing may create new domains and @@ -142,6 +155,15 @@ class TensorIndexer { // a broadcast-only loop group, should just use zero. bool shouldUseZeroIndex(const ValGroup& loop_group) const; + // For a given indexng traversal path toward allocation_domains, + // return the contiguous domains and their strides that can provide + // equivalent indexing results. + // + // Currently, only backward traversal is supported. + std::pair, std::vector> getContigDomainsAndStrides( + const IndexingAllocationInfo& alloc_info, + const ExprPath& traversal_path) const; + // Get a replace map for tensor indexing. Examples include replacing // an index of a vectorized loop with zero. // diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index 9bbb7a29a17..d2bb1e2d0a6 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -31,6 +31,7 @@ namespace nvfuser { using IndexingTest = NVFuserTest; using PredicateIndexingTest = NVFuserFixtureParamTest; +using ContigIndexingTest = NVFuserTest; namespace { @@ -270,7 +271,10 @@ class IndexValidator : public kir::IrVisitor { } template - static void validate(Fusion* fusion, Args... args) { + static void validate( + Fusion* fusion, + bool enable_contig_indexing, + Args... args) { EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set( EnableOption::IdModel, {"consumer_index", "producer_index"}); @@ -281,6 +285,9 @@ class IndexValidator : public kir::IrVisitor { DisableOptionsGuard::getCurOptions().set(DisableOption::IndexHoist); // Magic zero is not yet supported DisableOptionsGuard::getCurOptions().set(DisableOption::MagicZero); + if (!enable_contig_indexing) { + DisableOptionsGuard::getCurOptions().set(DisableOption::ContigIndexing); + } GpuLower lower(fusion); @@ -402,7 +409,10 @@ class PredicateIndexValidator : public kir::IrVisitor { } template - static void validate(Fusion* fusion, Args... args) { + static void validate( + Fusion* fusion, + bool enable_contig_indexing, + Args... args) { EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set( EnableOption::IdModel, {"predicate"}); @@ -415,6 +425,9 @@ class PredicateIndexValidator : public kir::IrVisitor { DisableOptionsGuard::getCurOptions().set(DisableOption::MagicZero); DisableOptionsGuard::getCurOptions().set( DisableOption::PredicateElimination); + if (!enable_contig_indexing) { + DisableOptionsGuard::getCurOptions().set(DisableOption::ContigIndexing); + } GpuLower lower(fusion); @@ -522,7 +535,7 @@ TEST_F(IndexingTest, SimplePointwise1) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Almost same fusion as SimplePointwiseSerial but TID and BID @@ -609,7 +622,7 @@ TEST_F(IndexingTest, SimplePointwise2) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Simple reduction with no parallelization @@ -657,7 +670,7 @@ TEST_F(IndexingTest, SimpleReduction) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Reduction with inlining. Loop promotion picks a reduction domain, @@ -705,7 +718,7 @@ TEST_F(IndexingTest, PromotionToReductionDomain) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Fusion copied from AllocationDomainTest.TransposedIntermediate @@ -750,7 +763,7 @@ TEST_F(IndexingTest, AllocationDomain) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } TEST_F(IndexingTest, Reshape) { @@ -842,7 +855,7 @@ TEST_F(IndexingTest, Reshape) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Simple non-concretized broadcast @@ -879,7 +892,7 @@ TEST_F(IndexingTest, SimpleBroadcast1) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // SimpleBroadcast1 + scheduling @@ -939,7 +952,7 @@ TEST_F(IndexingTest, SimpleBroadcast2) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Concretized broadcast @@ -1002,7 +1015,7 @@ TEST_F(IndexingTest, SimpleBroadcast3) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Concretized broadcast with partial inlining. Loop promotion is @@ -1061,7 +1074,7 @@ TEST_F(IndexingTest, SimpleBroadcast4) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Trivial example. 1D shared tensor. Each device only has one @@ -1089,7 +1102,7 @@ TEST_F(IndexingTest, MultiDevice1DNoSplitMerge) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Same fusion as MultiDevice1DNoSplitMerge but with split. @@ -1124,7 +1137,7 @@ TEST_F(IndexingTest, MultiDevice1DSplit) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } TEST_F(IndexingTest, MultiDevice2D) { @@ -1166,7 +1179,7 @@ TEST_F(IndexingTest, MultiDevice2D) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Same fusion as MultiDevice2D but with loop allocation @@ -1209,7 +1222,7 @@ TEST_F(IndexingTest, MultiDevice2DLeafAllocation) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } TEST_F(IndexingTest, MultiDevice2DTranspose) { @@ -1258,7 +1271,7 @@ TEST_F(IndexingTest, MultiDevice2DTranspose) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Allocation of broadcast domains should not need to be promoted. @@ -1304,7 +1317,7 @@ TEST_F(IndexingTest, PromotedBroadcast) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Simple vectorized copy @@ -1364,7 +1377,7 @@ TEST_F(IndexingTest, SimpleVectorize) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Test for reorderAllocationDomains. The vectorized @@ -1435,7 +1448,7 @@ TEST_F(IndexingTest, NonInnermostVectorize) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Indexing traversal failure repro due to non-size-one broadcast @@ -1500,7 +1513,7 @@ TEST_F(IndexingTest, AlmostExactTraversalWithNonOneBroadcast) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } TEST_F(IndexingTest, Swizzle) { @@ -1545,7 +1558,7 @@ TEST_F(IndexingTest, Swizzle) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Simple Unroll test. Unlike Unswitch, Unroll moves up allocation @@ -1608,7 +1621,7 @@ TEST_F(IndexingTest, SimpleUnroll) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Unrolling with no unrolled loop domain @@ -1668,7 +1681,7 @@ TEST_F(IndexingTest, InlinedUnroll) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } TEST_F(IndexingTest, SmemAllocationDomainForTranspose) { @@ -1753,7 +1766,7 @@ TEST_F(IndexingTest, SmemAllocationDomainForTranspose) { StmtNameType smem_tv_name; }; - IndexValidator::validate(&fusion, tv3->name()); + IndexValidator::validate(&fusion, false, tv3->name()); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input0 = at::randn({256, 256}, options); @@ -1862,7 +1875,7 @@ TEST_F(IndexingTest, ResizePath) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Same fusion as DoubleBufferingTest.DoubleBuffering1 @@ -1974,7 +1987,7 @@ TEST_F(IndexingTest, DoubleBuffering1) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Same fusion as DoubleBufferingTest.DoubleBuffering4 @@ -2080,7 +2093,7 @@ TEST_F(IndexingTest, DoubleBuffering4) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Same fusion as DoubleBufferingTest.DoubleBuffering6 @@ -2229,7 +2242,7 @@ TEST_F(IndexingTest, DoubleBuffering6) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Same fusion as DoubleBuffering1 but with >2 stages @@ -2347,7 +2360,7 @@ TEST_F(IndexingTest, CircularBuffering1) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Same fusion as DoubleBuffering6 but with >2 stages @@ -2502,7 +2515,7 @@ TEST_F(IndexingTest, CircularBuffering2) { } }; - IndexValidator::validate(&fusion); + IndexValidator::validate(&fusion, false); } // Same fusion as IndexingTest.SimplePointwise1 @@ -2558,7 +2571,7 @@ TEST_F(PredicateIndexingTest, SimplePointwise1) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); } // Testing predicate indexing with an rfactor reduction @@ -2638,7 +2651,7 @@ TEST_F(PredicateIndexingTest, ReductionRfactor) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); } // Same fusion as IndexingTest.SimpleUnroll @@ -2711,7 +2724,7 @@ TEST_F(PredicateIndexingTest, SimpleUnroll) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); } // Simple unswitch fusion. Unlike SimpleUnroll, it has multiple @@ -2794,7 +2807,7 @@ TEST_F(PredicateIndexingTest, SimpleUnswitch) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); } // Same fusion as IndexingTest.SimpleVectorize @@ -2861,7 +2874,7 @@ TEST_F(PredicateIndexingTest, SimpleVectorize) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); } // Same as IndexingTest.NonInnermostVectorize @@ -2935,7 +2948,7 @@ TEST_F(PredicateIndexingTest, NonInnermostVectorize) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); } // Same fusion as IndexingTest.DoubleBuffering1 @@ -3027,7 +3040,7 @@ TEST_F(PredicateIndexingTest, DoubleBuffering1) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1000}, options); @@ -3126,7 +3139,7 @@ TEST_F(PredicateIndexingTest, CircularBuffering1) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1000}, options); @@ -3293,7 +3306,7 @@ TEST_F(PredicateIndexingTest, UnrolledCircularBuffering) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1000}, options); @@ -3374,7 +3387,7 @@ TEST_F(PredicateIndexingTest, UnswitchedCircularBuffering1) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({99}, options); @@ -3463,7 +3476,7 @@ TEST_F(PredicateIndexingTest, UnswitchedCircularBuffering2) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1000}, options); @@ -3568,7 +3581,7 @@ TEST_P(PredicateIndexingTest, UnswitchedCircularBuffering3) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1000}, options); @@ -3647,7 +3660,7 @@ TEST_F(PredicateIndexingTest, UnswitchedCircularBuffering4) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); // Running this fusion with the legacy indexer would result in an // error if run with compute-sanitizer. @@ -3741,7 +3754,7 @@ TEST_F(PredicateIndexingTest, NonDivisibleSplit1) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); @@ -3832,7 +3845,7 @@ TEST_F(PredicateIndexingTest, NonDivisibleSplitWithUnswitch) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); @@ -3927,7 +3940,7 @@ TEST_F(PredicateIndexingTest, NonDivisibleSplitWithCircularBuffering) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); @@ -4038,7 +4051,7 @@ TEST_F( } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); @@ -4123,7 +4136,7 @@ TEST_P(PredicateIndexingTest, UnswitchPredicateIssueRepro681) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); EnableOptionsGuard enable_options_guard; if (GetParam()) { @@ -4282,7 +4295,7 @@ TEST_F(PredicateIndexingTest, NonDivisibleSplitWithUnswitchAndBroadcast) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({5}, options); @@ -4405,7 +4418,7 @@ TEST_F(PredicateIndexingTest, UnswitchConsolidationDifferentThreading) { } }; - PredicateIndexValidator::validate(&fusion); + PredicateIndexValidator::validate(&fusion, false); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1000}, options); @@ -4422,4 +4435,419 @@ TEST_F(PredicateIndexingTest, UnswitchConsolidationDifferentThreading) { testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } +// Same fusion as SimplePointwise1 but with contig indexing +TEST_F(ContigIndexingTest, SimplePointwise) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + fusion.addOutput(tv2); + + tv2->flatten(); + tv2->split(0, 4); + + TransformPropagator propagator(tv2); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); + + tv1->inlineAt(1); + + // Because of contig indexing, the index of tv0 and tv2 should be + // just: i0 * 4 + i1. + struct GetReference : AbstractGetReference { + GetReference(const TensorIndexer& indexer, const IdModel& id_model) + : AbstractGetReference(indexer, id_model) {} + + Val* getLinearIndex(TensorView* tv, TensorView* maybe_consumer) + const override { + bool as_consumer = maybe_consumer == nullptr; + auto consumer_tv = as_consumer ? tv : maybe_consumer; + std::vector loop_indices = getLoopIndices(consumer_tv, indexer_); + switch (tv->name()) { + case 0: { + NVF_ERROR(!as_consumer); + return addExpr( + mulExpr(loop_indices.at(0), consumer_tv->axis(1)->extent()), + loop_indices.at(1)); + } + case 1: { + return loop_indices.at(1); + } + case 2: { + NVF_ERROR(as_consumer); + return addExpr( + mulExpr(loop_indices.at(0), consumer_tv->axis(1)->extent()), + loop_indices.at(1)); + } + default: + NVF_ERROR(false, "Unexpected tensor: ", tv->toString()); + break; + } + return nullptr; + } + }; + + IndexValidator::validate(&fusion, true); +} + +TEST_F(ContigIndexingTest, NonContigInnermost) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Innermost dimension is non contiguous but the other two + // dimensions are contiguous. + auto tv0 = TensorViewBuilder() + .ndims(3) + .dtype(DataType::Float) + .contiguity({true, true, false}) + .build(); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + // [I0, I1, I2] + tv1->merge(1); + // [I0, I1*I2] + + // Since the i1 contig flag is true, the merge is contiguous even + // though i2 is not contiguous. The producer index of tv0 should be: + // i0 * I0_stride + i1 * I2_stride. The stride of I0 should be + // calculated as I2_stride * I2_extent * I1_extent. + // + // As for tv1, since it's fully contiguous, it should also be i0 * + // I0_stride + i1. Here, I0_stride should be I2_extent * I1_extent. + struct GetReference : AbstractGetReference { + GetReference(const TensorIndexer& indexer, const IdModel& id_model) + : AbstractGetReference(indexer, id_model) {} + + Val* getLinearIndex(TensorView* tv, TensorView* maybe_consumer) + const override { + bool as_consumer = maybe_consumer == nullptr; + auto consumer_tv = as_consumer ? tv : maybe_consumer; + std::vector loop_indices = getLoopIndices(consumer_tv, indexer_); + switch (tv->name()) { + case 0: { + NVF_ERROR(!as_consumer); + auto i0_stride = mulExpr( + mulExpr( + IrBuilder::getItemExpr( + IrBuilder::getAttrExpr( + IrBuilder::metadataExpr(tv), "alloc_stride"), + IrBuilder::create(2, DataType::Int)), + tv->getLogicalDomain().at(2)->extent()), + tv->getLogicalDomain().at(1)->extent()); + auto i2_stride = IrBuilder::getItemExpr( + IrBuilder::getAttrExpr( + IrBuilder::metadataExpr(tv), "alloc_stride"), + IrBuilder::create(2, DataType::Int)); + return addExpr( + mulExpr(loop_indices.at(0), i0_stride), + mulExpr(loop_indices.at(1), i2_stride)); + } + case 1: { + NVF_ERROR(as_consumer); + return addExpr( + mulExpr( + loop_indices.at(0), + mulExpr( + consumer_tv->getLogicalDomain().at(2)->extent(), + consumer_tv->getLogicalDomain().at(1)->extent())), + loop_indices.at(1)); + } + default: + NVF_ERROR(false, "Unexpected tensor: ", tv->toString()); + break; + } + return nullptr; + } + }; + + IndexValidator::validate(&fusion, true); +} + +// Contig indexing with broadcast inlining +TEST_F(ContigIndexingTest, BroadcastInlining) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(2); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->merge(0); + + TransformPropagator propagator(tv4); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); + + inlineMost(); + + // t4 is indexed at the merge output domain, so its index should be + // just its sole loop index. t2 and t3 are fully inlined + // intermediate tensors, so their indices are just zero. Since t1 is + // contiguous, it's also just indexed with the loop index. t0, on + // the other hand, needs to back traverse the merge since its sole + // index domain corresponds to the inner merge input domain. + struct GetReference : AbstractGetReference { + GetReference(const TensorIndexer& indexer, const IdModel& id_model) + : AbstractGetReference(indexer, id_model) {} + + Val* getLinearIndex(TensorView* tv, TensorView* maybe_consumer) + const override { + bool as_consumer = maybe_consumer == nullptr; + auto consumer_tv = as_consumer ? tv : maybe_consumer; + std::vector loop_indices = getLoopIndices(consumer_tv, indexer_); + switch (tv->name()) { + case 0: { + NVF_ERROR(!as_consumer); + return modExpr( + loop_indices.at(0), tv->getLogicalDomain().at(0)->extent()); + } + case 1: { + NVF_ERROR(!as_consumer); + return loop_indices.at(0); + } + case 2: + case 3: + return tv->fusion()->zeroVal(); + case 4: { + NVF_ERROR(as_consumer); + return loop_indices.at(0); + } + default: + NVF_ERROR(false, "Unexpected tensor: ", tv->toString()); + break; + } + return nullptr; + } + }; + + IndexValidator::validate(&fusion, true); +} + +// Merge after resize is not allowed to do contig indexing even when +// the original input domains are contiguous. +TEST_F(ContigIndexingTest, Resize) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({11, 30}); + + NVF_CHECK(shape[1] % 2 == 0); + + auto tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = slice(tv0, {0, shape[1] / 2}, {shape[0], shape[1]}); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + // Contig merge + tv2->merge(0); + + TransformPropagator propagator(tv2); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); + + // All tensors except for tv0 are indexed at the output of the merge + // op, so their indices should be just loop_indices[0]. However, for + // tv0, since the merge follows a resize, indexing is done at the + // resize input domain. + struct GetReference : AbstractGetReference { + GetReference(const TensorIndexer& indexer, const IdModel& id_model) + : AbstractGetReference(indexer, id_model) {} + + Val* getLinearIndex(TensorView* tv, TensorView* maybe_consumer) + const override { + bool as_consumer = maybe_consumer == nullptr; + auto consumer_tv = as_consumer ? tv : maybe_consumer; + std::vector loop_indices = getLoopIndices(consumer_tv, indexer_); + switch (tv->name()) { + case 0: { + NVF_ERROR(!as_consumer); + auto id0 = mulExpr( + divExpr( + loop_indices.at(0), + consumer_tv->getLogicalDomain().at(1)->extent()), + tv->getLogicalDomain().at(1)->extent()); + auto resize = dynamic_cast( + consumer_tv->getLogicalDomain().at(1)->definition()); + NVF_ERROR(resize != nullptr); + auto id1 = subExpr( + modExpr( + loop_indices.at(0), + consumer_tv->getLogicalDomain().at(1)->extent()), + resize->leftExpand()); + return addExpr(id0, id1); + } + case 1: + case 2: + return loop_indices.at(0); + default: + NVF_ERROR(false, "Unexpected tensor: ", tv->toString()); + break; + } + return nullptr; + } + }; + + IndexValidator::validate(&fusion, true); +} + +// Contiguous tensor but merge order breaks contiguity +TEST_F(ContigIndexingTest, NonConsistentMerge) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(3); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + fusion.addOutput(tv1); + + tv1->merge(0, 2); + tv1->merge(0, 1); + + TransformPropagator propagator(tv1); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); + + // Make sure both tv0 and tv1 are indexed without contig indexing + struct GetReference : AbstractGetReference { + GetReference(const TensorIndexer& indexer, const IdModel& id_model) + : AbstractGetReference(indexer, id_model) {} + + Val* getLinearIndex(TensorView* tv, TensorView* maybe_consumer) + const override { + bool as_consumer = maybe_consumer == nullptr; + auto consumer_tv = as_consumer ? tv : maybe_consumer; + std::vector loop_indices = getLoopIndices(consumer_tv, indexer_); + + auto id0 = divExpr( + divExpr(loop_indices.at(0), tv->getLogicalDomain().at(1)->extent()), + tv->getLogicalDomain().at(2)->extent()); + auto id0_extent = mulExpr( + tv->getLogicalDomain().at(2)->extent(), + tv->getLogicalDomain().at(1)->extent()); + auto id1 = + modExpr(loop_indices.at(0), tv->getLogicalDomain().at(1)->extent()); + auto id1_extent = tv->getLogicalDomain().at(2)->extent(); + auto id2 = modExpr( + divExpr(loop_indices.at(0), tv->getLogicalDomain().at(1)->extent()), + tv->getLogicalDomain().at(2)->extent()); + return addExpr( + addExpr(mulExpr(id0, id0_extent), mulExpr(id1, id1_extent)), id2); + } + }; + + IndexValidator::validate(&fusion, true); +} + +TEST_F(ContigIndexingTest, ConcretizedBroadcastMerge) { + Fusion fusion; + FusionGuard fg(&fusion); + + // [I0, I1] + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + // [I0, I1, I2] + auto tv1 = makeContigTensor(3); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {false, false, true}); + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); + + tv3->merge(1, 2); + tv3->merge(0, 1); + + TransformPropagator propagator(tv3); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv3->axis(0)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3); + + tv2->setMemoryType(MemoryType::Shared); + + // tv2's broadcast domain is concretized. Previously, this would + // have prevented contig indexing. + + struct GetReference : AbstractGetReference { + GetReference(const TensorIndexer& indexer, const IdModel& id_model) + : AbstractGetReference(indexer, id_model) {} + + Val* getLinearIndex(TensorView* tv, TensorView* maybe_consumer) + const override { + // Only interested in tv2 here since that's the one that has a + // concretized broadcast domain + if (tv->name() != 2) { + return nullptr; + } + + bool as_consumer = maybe_consumer == nullptr; + auto consumer_tv = as_consumer ? tv : maybe_consumer; + std::vector loop_indices = getLoopIndices(consumer_tv, indexer_); + + // When indexed as a consumer, the second merge is a contig + // merge, so the index should be just threadIdx.x + if (as_consumer) { + return loop_indices.at(0); + } + + // When indexed as a producer of tv1, the loop domain has all + // the concrete domains merged, so it needs to be + // decomposed. Specifically, the loop domain, threadIdx.x, should be + // decomposed as: + // + // Index of the outer logical domain: tidx / (I1 * I2) + // Index of the inner logical domain: tidx % (I1 * I2) / I2 + // + // Since the allocation domain of t2 is (I0 * I1), the final + // index is (tidx / (I1 * I2) * I1 + tidx % (I1 * I2) / I2) + + auto logical0 = divExpr( + loop_indices.at(0), + mulExpr( + consumer_tv->getLogicalDomain().at(1)->extent(), + consumer_tv->getLogicalDomain().at(2)->extent())); + + auto logical1 = divExpr( + modExpr( + loop_indices.at(0), + mulExpr( + consumer_tv->getLogicalDomain().at(1)->extent(), + consumer_tv->getLogicalDomain().at(2)->extent())), + consumer_tv->getLogicalDomain().at(2)->extent()); + + auto alloc0 = addExpr( + mulExpr(logical0, tv->getLogicalDomain().at(1)->extent()), logical1); + + return alloc0; + } + }; + + IndexValidator::validate(&fusion, true); + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({5, 6}, options); + auto t1 = at::randn({5, 6, 7}, options); + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); +} + } // namespace nvfuser From 0b5d8defff2ae3f87e44046dffec210922cc5beb Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 26 Aug 2024 11:43:46 -0700 Subject: [PATCH 11/54] Support getPerDimLogicalIndices (#2843) Adding `Index::getConsumerPerDimLogicalIndex` and `Index::getProducerPerDimLogicalIndex`. These are used, for example, when lowering `CatOp` (https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/pass/index.cpp#L1945). Only added one simple unit test since these are really trivial with the new indexer, whereas previously we used an ugly WAR like `allocateToLogicalDomainGuard`. --- csrc/index_compute.cpp | 41 +++++++++++++--- tests/cpp/test_indexing.cpp | 97 +++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 7 deletions(-) diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index f18e25b41e2..5eb58e3e2b4 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -1615,10 +1615,24 @@ std::vector Index::getConsumerPerDimLogicalIndex( TensorView* consumer_tv, const std::vector& loops, const std::unordered_set& rotated_loops) { - auto guard = ir_utils::allocateToLogicalDomainGuard(consumer_tv, false); - IndexFromIdGraph index_from_id_graph = - getTensorIndexFromIdGraph(loops, rotated_loops, consumer_tv); - return getConsumerAllocationIndices(consumer_tv, loops, index_from_id_graph); + if (!lower_utils::hasRootToLoopLinearTransformations(consumer_tv) || + (isIdModelOptionEnabled(IdModelEnableOption::ConsumerIndex) && + GpuLower::current()->isTensorIndexerEnabled())) { + const TensorIndexer& indexer = GpuLower::current()->tensorIndexer(); + ValGroups logical_indices = + indexer.traversalGraph().toGroups(consumer_tv->getLogicalDomain()); + return indexer.getIndexFor( + consumer_tv->definition(), + /*as_consumer=*/true, + logical_indices, + loops); + } else { + auto guard = ir_utils::allocateToLogicalDomainGuard(consumer_tv, false); + IndexFromIdGraph index_from_id_graph = + getTensorIndexFromIdGraph(loops, rotated_loops, consumer_tv); + return getConsumerAllocationIndices( + consumer_tv, loops, index_from_id_graph); + } } std::vector Index::getProducerPerDimLogicalIndex( @@ -1627,9 +1641,22 @@ std::vector Index::getProducerPerDimLogicalIndex( const std::vector& loops, const std::unordered_set& rotated_loops, const std::unordered_map& override_index) { - auto guard = ir_utils::allocateToLogicalDomainGuard(producer_tv, false); - return getProducerAllocationIndices( - producer_tv, consumer_tv, loops, rotated_loops, override_index); + if (!lower_utils::hasRootToLoopLinearTransformations(producer_tv) || + (isIdModelOptionEnabled(IdModelEnableOption::ProducerIndex) && + GpuLower::current()->isTensorIndexerEnabled())) { + const TensorIndexer& indexer = GpuLower::current()->tensorIndexer(); + ValGroups logical_indices = + indexer.traversalGraph().toGroups(producer_tv->getLogicalDomain()); + return indexer.getIndexFor( + consumer_tv->definition(), + /*as_consumer=*/false, + logical_indices, + loops); + } else { + auto guard = ir_utils::allocateToLogicalDomainGuard(producer_tv, false); + return getProducerAllocationIndices( + producer_tv, consumer_tv, loops, rotated_loops, override_index); + } } std::vector Index::getStrides(TensorView* tv) { diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index d2bb1e2d0a6..b1eb4ea80e1 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -4850,4 +4850,101 @@ TEST_F(ContigIndexingTest, ConcretizedBroadcastMerge) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } +TEST_F(IndexingTest, PerDimLogicalIndices) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({4, 8}); + fusion.addInput(tv0); + auto tv1 = reshape(tv0, {4, 8}, {32}); + fusion.addOutput(tv1); + + tv1->split(0, 4); + tv1->split(0, 128); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv1->axis(2)->parallelize(ParallelType::Unroll); + + auto validate_per_dim_indices = + [](const std::vector& exprs) -> std::vector { + class Validator : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + using kir::IrVisitor::dispatch; + + void handle(LoadStoreOp* ls) override { + // There should be only one expression of tv1 = Set(tv0). + NVF_ERROR(ls->in()->isA()); + auto tv0 = ls->in()->as()->view(); + NVF_ERROR(tv0->name() == 0); + + NVF_ERROR(ls->out()->isA()); + auto tv1 = ls->out()->as()->view(); + NVF_ERROR(tv1->name() == 1); + + auto indexer = GpuLower::current()->tensorIndexer(); + auto loop_indices = getLoopIndices(tv1, indexer); + + // The logical domains of tv0 and tv1 are [i0, i1] and + // [i0*i1], respectively. Since tv1 is split twice, the + // logical domain of tv1 is obtained by traversing them from + // the three loop iter domains. + auto tv1_logical_index = addExpr( + mulExpr( + addExpr( + mulExpr(loop_indices.at(0), createInt(128)), + loop_indices.at(1)), + createInt(4)), + loop_indices.at(2)); + + // The tv0 logical indices are obtained by traversing through + // the merge for the reshape op. + std::vector tv0_logical_indices{ + divExpr(tv1_logical_index, tv0->getLogicalDomain().at(1)->extent()), + modExpr( + tv1_logical_index, tv0->getLogicalDomain().at(1)->extent())}; + + // Check tv1 logical indices + auto actual_tv1_logial_indices = + Index::getConsumerPerDimLogicalIndex(tv1, for_loops_, {}); + ASSERT_EQ(actual_tv1_logial_indices.size(), 1); + EXPECT_TRUE(actual_tv1_logial_indices[0]->sameAs(tv1_logical_index)) + << "Validation failure of " << tv1->toString() << " as consumer" + << "\nRef: " << tv1_logical_index->toInlineString() + << "\nActual: " << actual_tv1_logial_indices[0]->toInlineString(); + + // Check tv0 logical indices + auto actual_tv0_logial_indices = + Index::getProducerPerDimLogicalIndex(tv0, tv1, for_loops_, {}); + ASSERT_EQ(actual_tv0_logial_indices.size(), tv0_logical_indices.size()); + for (const auto i : c10::irange(tv0_logical_indices.size())) { + EXPECT_TRUE( + actual_tv0_logial_indices[i]->sameAs(tv0_logical_indices[i])) + << "Validation failure of " << tv0->toString() << " as producer" + << "\nRef: " << tv0_logical_indices[0]->toInlineString() + << "\nActual: " << actual_tv0_logial_indices[i]->toInlineString(); + } + } + }; + + Validator validator; + validator.handle(exprs); + + return exprs; + }; + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + DisableOptionsGuard disable_options_guard; + DisableOptionsGuard::getCurOptions().set(DisableOption::ExprSimplify); + DisableOptionsGuard::getCurOptions().set(DisableOption::IndexHoist); + + GpuLower lower(&fusion); + lower.passes().insert( + lower.passes().end(), + {"validate_per_dim_indices", validate_per_dim_indices}); + lower.run(); +} + } // namespace nvfuser From 3b61042e2085b3fef3716cb04ae9c4df0067075d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:50:06 -0400 Subject: [PATCH 12/54] Re-write replaceSymbolicSizes using IdModel (#2714) This uses IdModel to implement `replaceSymbolicSizes`. Extents are replaced with a single representative from their exact graph ValGroup with the following precedence: 1. Constants are preferred 2. If no constants exist, prefer the extents of fusion inputs. 3. Ties are broken by choosing the scalar with the smallest `name()`. Fixes #2702. Fixes #2766 --------- Co-authored-by: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> --- csrc/device_lower/pass/replace_size.cpp | 255 ++++++++++++------------ tests/cpp/test_gpu3.cpp | 51 +++++ tests/cpp/test_gpu_outer_reduction.cpp | 37 ++++ tests/cpp/test_smem_reuse.cpp | 13 +- tests/python/test_normalization.py | 27 +++ tests/python/test_python_frontend.py | 49 +++++ 6 files changed, 303 insertions(+), 129 deletions(-) diff --git a/csrc/device_lower/pass/replace_size.cpp b/csrc/device_lower/pass/replace_size.cpp index 26cdc51a174..1e1bc8b9738 100644 --- a/csrc/device_lower/pass/replace_size.cpp +++ b/csrc/device_lower/pass/replace_size.cpp @@ -6,11 +6,13 @@ */ // clang-format on #include +#include #include #include #include #include #include +#include #include @@ -35,131 +37,91 @@ namespace { // concice there to pull out. May want to consider making this mapping its own // class especially as it may be useful during scheduling. std::unordered_map getSimplificationMap(Fusion* fusion) { - std::list> disjoint_root_sets; - std::unordered_map*> - id_to_disjoint_root_set; - - auto map_root_ids = [&disjoint_root_sets, &id_to_disjoint_root_set]( - IterDomain* id0, IterDomain* id1) { - if (id0->isBroadcast() || id1->isBroadcast()) { - return; - } - - if (id0->isGatherScatter() || id1->isGatherScatter()) { - return; - } + IdModel id_model(fusion, /*build_graphs=*/false); + id_model.buildExactGraph(); + ValGraph& graph = id_model.idGraph(IdMappingMode::EXACT); - auto disjoint_set_0_it = id_to_disjoint_root_set.find(id0); - auto disjoint_set_1_it = id_to_disjoint_root_set.find(id1); - bool set_0_found = disjoint_set_0_it != id_to_disjoint_root_set.end(); - bool set_1_found = disjoint_set_1_it != id_to_disjoint_root_set.end(); - - if (set_0_found && set_1_found) { - if (disjoint_set_0_it->second == disjoint_set_1_it->second) { - return; - } - // merge second disjoint set into first - auto* set_0 = disjoint_set_0_it->second; - auto* set_1 = disjoint_set_1_it->second; - for (auto id : *set_1) { - set_0->emplace(id); - id_to_disjoint_root_set[id] = set_0; - } - // remove second set from disjoint_root_sets - disjoint_root_sets.erase(std::find( - disjoint_root_sets.begin(), disjoint_root_sets.end(), *set_1)); - } else if (set_0_found || set_1_found) { - auto existing_set = - set_0_found ? disjoint_set_0_it->second : disjoint_set_1_it->second; - auto to_add_id = set_0_found ? id1 : id0; - existing_set->emplace(to_add_id); - id_to_disjoint_root_set[to_add_id] = existing_set; - // add entry into existing set - } else { - // create new set entry - disjoint_root_sets.emplace_back(); - auto* new_set = &disjoint_root_sets.back(); - new_set->emplace(id0); - new_set->emplace(id1); - id_to_disjoint_root_set[id0] = new_set; - id_to_disjoint_root_set[id1] = new_set; - } - }; - - auto fusion_vals = fusion->usedMathVals(); - for (auto producer_tv : ir_utils::filterByType(fusion_vals)) { - auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv); - for (auto consumer_tv : consumer_tvs) { - auto pairwise_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv); - auto c2p_logical_map = pairwise_map.mapConsumerToProducer(); - for (auto entry : c2p_logical_map) { - auto c_id = entry.first; - auto p_id = entry.second; - map_root_ids(p_id, c_id); + std::unordered_set fusion_input_ids; + for (Val* v : fusion->inputs()) { + if (auto* tv = dynamic_cast(v)) { + for (IterDomain* id : tv->getLogicalDomain()) { + fusion_input_ids.insert(id); } } } - // Map each set to an input ID (if it exists) that has the smallest ->name() - // entry value - std::unordered_map*, IterDomain*> - set_to_input_id; - - // Loop over the root domains, of the inputs to the fusion. Pick an input ID - // to use as the representative ID of the collected sets. Only consider inputs - // as those are the ones that map to values like "T0.size[1]". They are he - // ID's that propagated their extents into the problem. We could also check - // the outputs as we do have C++ examples of using output dimensions for the - // problem size instead of inputs. However, we don't do anything where we can - // translate to those kinds of kernels integrated into PyTorch. - for (auto input_tv : ir_utils::filterByType(fusion->inputs())) { - for (auto id : TensorDomain::noReductions(input_tv->getLogicalDomain())) { - auto id_set_it = id_to_disjoint_root_set.find(id); - if (id_set_it == id_to_disjoint_root_set.end()) { + std::unordered_map simplification_map; + + for (const ValGroup& group : graph.disjointValSets().disjointSets()) { + // For each ValGroup, find a single extent to use for all extents of + // IterDomains in the group. These are chosen in descending order of + // preference: + // 1. Constant ints. These might be non-immediate constants + // 2. Extents of input TVs. + // 3. Extents of non-input TVs. + // Within these three classes, we find the IterDomain with the smallest + // name(). + bool group_is_const = false; + IterDomain* rep = nullptr; + bool rep_is_input_id = false; + std::unordered_set dynamic_scalars; + for (Val* v : *group) { + auto* id = dynamic_cast(v); + NVF_ERROR( + id != nullptr, "Expected only IterDomains in exact graph ValGroups"); + bool is_input_id = fusion_input_ids.count(id) > 0; + if (rep == nullptr) { + rep = id; + rep_is_input_id = is_input_id; continue; } - auto* id_set = id_set_it->second; - if (set_to_input_id.find(id_set) == set_to_input_id.end()) { - set_to_input_id[id_set] = id; + Val* ext = id->extent(); + bool ext_is_const = ext->isConstInt(); + if (!ext_is_const) { + dynamic_scalars.insert(ext); + } + + if (ext_is_const) { + if (!group_is_const || id->name() < rep->name()) { + rep = id; + // This lets us avoid repeating the costly isConstInt check + group_is_const = true; + rep_is_input_id = is_input_id; + continue; + } + } else if (is_input_id) { + if (group_is_const) { + continue; + } + if (!rep_is_input_id || id->name() < rep->name()) { + rep = id; + rep_is_input_id = is_input_id; + continue; + } } else { - auto input_id_of_set = set_to_input_id.at(id_set); - // Swap id's if new name is less than previously set - bool swap_ids = id->name() < input_id_of_set->name(); - // If new id is a const scalar but previously was'nt use the const - // scalar - swap_ids = swap_ids || - (id->extent()->isConstScalar() && - !input_id_of_set->extent()->isConstScalar()); - // If previous scalar was const and new isn't, don't swap - swap_ids = swap_ids && - !(input_id_of_set->extent()->isConstScalar() && - !id->extent()->isConstScalar()); - - if (swap_ids) { - set_to_input_id[id_set] = id; + // id is a non-input TV + if (group_is_const || rep_is_input_id) { + continue; + } + if (id->name() < rep->name()) { + rep = id; + rep_is_input_id = is_input_id; + continue; } } } - } - - // Finally make map from ID extents to the representitive ID extent. - std::unordered_map extent_to_min_input_id_extent; - for (auto entry : set_to_input_id) { - auto* set = entry.first; - auto input_id = entry.second; - for (auto id : *set) { - auto prev_it = extent_to_min_input_id_extent.find(id->extent()); - // We loop in an unspecified order, so we might overwrite - // extent_to_min_input_id_extent[id->extent()]. For reproducibility's - // sake, only do so if it would lower the index of the mapped value. - if (prev_it != extent_to_min_input_id_extent.end() && - prev_it->second->name() <= input_id->extent()->name()) { - continue; + NVF_ERROR(rep != nullptr); + Val* rep_ext = rep->extent(); + for (Val* v : *group) { + auto* id = v->as(); + Val* ext = id->extent(); + // Don't remap constants or rep_ext itself + if (!ext->sameAs(rep_ext) && dynamic_scalars.count(ext)) { + simplification_map.emplace(ext, rep_ext); } - extent_to_min_input_id_extent[id->extent()] = input_id->extent(); } } - return extent_to_min_input_id_extent; + return simplification_map; } } // namespace @@ -228,22 +190,71 @@ void replaceSymbolicSizes(Fusion* fusion) { } } - // Use a minimal number of sizes from provided tensors. + // Simplify extents for each exact ValGroup in the fusion auto extent_simplification_map = getSimplificationMap(fusion); - for (auto extent_entry : extent_simplification_map) { - auto orig_extent = extent_entry.first; - auto simplified_extent = extent_entry.second; - if (tensor_dim_map.count(orig_extent)) { - if (tensor_dim_map.count(simplified_extent)) { - tensor_dim_map[orig_extent] = tensor_dim_map[simplified_extent]; - } else { - tensor_dim_map[orig_extent] = simplified_extent; - } + + // We now need to map replacement scalars to their targets in tensor_dim_map + // if they exist. To do this we compose extent_simplification_map with + // tensor_dim_map. + // + // Example: + // + // T0[ i0, i1 ] + // T1[ i2, i3 ] + // T2[ i4 ] + // T3 = T0 + T1 + // T4 = T2 * full({5}, 0) + // ... + // + // tensor_dim_map: + // i0 = getMetaData[T0].logical_size[0] + // i1 = getMetaData[T0].logical_size[1] + // i2 = getMetaData[T1].logical_size[0] + // i3 = getMetaData[T1].logical_size[1] + // i4 = getMetaData[T2].logical_size[0] + // + // extent_simplification_map: + // i2 = i0 + // i3 = i1 + // i4 = 5 + // + // In this loop, we update the _target_ values like so: + // + // extent_simplification_map (updated): + // i2 = getMetaData[T0].logical_size[0] + // i3 = getMetaData[T0].logical_size[1] + // i4 = 5 + // + // Note that i4's entry is not updated since i4 does not map to a key from + // tensor_dim_map. + for (auto& [orig_extent, simplified_extent] : extent_simplification_map) { + auto it = tensor_dim_map.find(simplified_extent); + if (it != tensor_dim_map.end()) { + // Update the mapped extent value + simplified_extent = it->second; + } + } + // Now add entries from tensor_dim_map, being careful not to overwrite + // existing replacements. + // + // Using the example from above, at this point extent_simplification_map is + // missing entries for i0 and i1, so we add those directly from + // tensor_dim_map: + // + // extent_simplification_map (updated): + // i0 = getMetaData[T0].logical_size[0] + // i1 = getMetaData[T0].logical_size[1] + // i2 = getMetaData[T0].logical_size[0] + // i3 = getMetaData[T0].logical_size[1] + // i4 = 5 + for (auto [tensor_dim, meta_expr] : tensor_dim_map) { + if (extent_simplification_map.count(tensor_dim) == 0) { + extent_simplification_map[tensor_dim] = meta_expr; } } // Run mutation on the fusion with the tensor_dim_map - ir_utils::replaceValue(fusion, tensor_dim_map); + ir_utils::replaceValue(fusion, extent_simplification_map); } } // namespace nvfuser diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 7c9b16b392e..61caf77ab1e 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -8770,6 +8771,56 @@ TEST_F(NVFuserTest, Issue2685Repro) { testValidate(&fusion_copy, outputs, inputs, __LINE__, __FILE__); } +// Check that extents are properly replaced by replaceSymbolicSizes lowering +// pass +TEST_F(NVFuserTest, ReplaceSymbolicSizes) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + auto tv2 = makeSymbolicTensor(1); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + + auto tv3 = add(tv0, tv1); + auto tv4 = full( + {IrBuilder::create(5, DataType::Index)}, + IrBuilder::create(2.0, DataType::Float), + DataType::Float); + auto tv5 = mul(tv2, tv4); + + fusion->addOutput(tv3); + fusion->addOutput(tv5); + + replaceSymbolicSizes(fusion); + + // tv0's extents map to their corresponding getMetaData expressions + EXPECT_EQ( + tv0->axis(0)->extent()->toInlineString(), + "( (( (( getMetaData(T0) )).logical_size ))[0] )"); + EXPECT_EQ( + tv0->axis(1)->extent()->toInlineString(), + "( (( (( getMetaData(T0) )).logical_size ))[1] )"); + EXPECT_EQ( + tv1->axis(0)->extent()->toInlineString(), + "( (( (( getMetaData(T0) )).logical_size ))[0] )"); + EXPECT_EQ( + tv1->axis(1)->extent()->toInlineString(), + "( (( (( getMetaData(T0) )).logical_size ))[1] )"); + EXPECT_EQ( + tv3->axis(0)->extent()->toInlineString(), + "( (( (( getMetaData(T0) )).logical_size ))[0] )"); + EXPECT_EQ( + tv3->axis(1)->extent()->toInlineString(), + "( (( (( getMetaData(T0) )).logical_size ))[1] )"); + + EXPECT_EQ(tv2->axis(0)->extent()->toInlineString(), "5"); + EXPECT_EQ(tv5->axis(0)->extent()->toInlineString(), "5"); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser diff --git a/tests/cpp/test_gpu_outer_reduction.cpp b/tests/cpp/test_gpu_outer_reduction.cpp index 050d4423eec..f6c120c5aba 100644 --- a/tests/cpp/test_gpu_outer_reduction.cpp +++ b/tests/cpp/test_gpu_outer_reduction.cpp @@ -2559,4 +2559,41 @@ TEST_F(OuterReductionTest, IterGroupedMultipleReductions) { lparams); } +// Repro of https://github.com/NVIDIA/Fuser/pull/2766 +TEST_F(NVFuserTest, SmallOuterBlockReductionIssue2766) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + std::vector shape{100, 2, 128}; + + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + + auto tv1 = reshape( + tv0, + {IrBuilder::create(shape[0]), + IrBuilder::create(shape[1]), + IrBuilder::create(shape[2])}); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + + // Previously, after the extent replacement of the lowering, the reduction + // reference tensor got a reduction domain of a static size, which is just 1, + // but the pre-reshape tensors still kept using symbolic extents. Before + // https://github.com/NVIDIA/Fuser/pull/2714, the scheduler decided to not use + // TIDy because the reference tensor has a static size of 1, but since the + // other tensors still had dynamic sizes, it resulted in the dynamic + // allocation error. + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({shape[0] * shape[1], shape[2]}, options); + std::vector inputs({t0}); + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto outputs = fec.runFusionWithInputs(inputs); + + testValidate(fec.fusion(), outputs, inputs, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/tests/cpp/test_smem_reuse.cpp b/tests/cpp/test_smem_reuse.cpp index 2e48d57795c..952d01caa97 100644 --- a/tests/cpp/test_smem_reuse.cpp +++ b/tests/cpp/test_smem_reuse.cpp @@ -652,8 +652,11 @@ TEST_F(SmemReuseTest, ExpandInterferes) { tv->setMemoryType(MemoryType::Shared); } - // tv3 is trying to reuse tv1's memory. however it has a concrete size. - // The reuse only happens when tv1 is also concrete. + // tv3 is trying to reuse tv1's memory. Even though tv3 has a concrete size + // and tv1 might not, tv1's extent will always be replaced by the constant + // since they are exact mapped (see the replaceSymbolicSizes lowering + // pass). Otherwise, we would only do this replacement when tv1 is also + // concrete. { bool t3_alias_t1 = false; GpuLower gpulw(fusion.get()); @@ -669,11 +672,7 @@ TEST_F(SmemReuseTest, ExpandInterferes) { } } } - if (is_concrete) { - EXPECT_TRUE(t3_alias_t1); - } else { - EXPECT_FALSE(t3_alias_t1); - } + EXPECT_TRUE(t3_alias_t1); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); diff --git a/tests/python/test_normalization.py b/tests/python/test_normalization.py index 2dcc4244a5b..c009b98d02a 100644 --- a/tests/python/test_normalization.py +++ b/tests/python/test_normalization.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn +from nvfuser import FusionDefinition, DataType from nvfuser.contrib.nn.normalization import InstanceNorm3dNVFuser @@ -172,3 +173,29 @@ def forward(self, x): pred = model(x) loss = nn.functional.mse_loss(pred, y.float()) loss.backward() + + +# Test that split extents are properly replaced with constants +# See https://github.com/NVIDIA/Fuser/issues/2702 +def test_issue2702(): + def create_fusion(fd: FusionDefinition) -> None: + T4 = fd.define_tensor( + shape=[1, -1, -1, -1], + contiguity=[None, True, True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[3, 2, 1, 0], + ) + T75 = fd.ops.reshape(T4, new_shape=[1, 8, 4, 8192, 128]) + T90 = fd.ops.cast(T75, dtype=DataType.Float) + T91 = fd.ops.sum(T90, dims=[0, 2], keepdim=False, dtype=DataType.Null) + T92 = fd.ops.cast(T91, dtype=DataType.BFloat16) + fd.add_output(T92) + + with FusionDefinition() as fd: + create_fusion(fd) + + ins = [torch.randn((1, 32, 8192, 128), dtype=torch.bfloat16, device="cuda:0")] + outs = fd.execute(ins) + + torch.testing.assert_close(outs[0], ins[0].view(8, 4, 8192, 128).sum(1)) diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 431330dd43f..f52ee06aa65 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4280,3 +4280,52 @@ def fusion_func(fd: FusionDefinition): ), ): nvf_out = fd.execute([tensor_inp, 2.0 + 1.0j]) + + # Test that replaced sizes using input tensor metadata are successfully computed + # See https://github.com/NVIDIA/Fuser/pull/2714 which surfaced this in + # failing thunder test + # thunder.tests.test_core.test_bsym_toposort_nvfuser_cuda_thunder.dtypes.float32 + def test_replaced_sizes_pr2714(self): + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[1, 0], + ) + T2 = fd.ops.exp(T0) + T3 = fd.ops.tanh(T1) + S4 = fd.define_scalar(4, dtype=DataType.Int) + V5 = fd.define_vector([S4], dtype=DataType.Int) + T6 = fd.ops.reshape(T2, new_shape=V5) + S7 = fd.define_scalar(4, dtype=DataType.Int) + V8 = fd.define_vector([S7], dtype=DataType.Int) + T9 = fd.ops.reshape(T3, new_shape=V8) + T10 = fd.ops.add(T6, T9) + T11 = fd.ops.reciprocal(T0) + T12 = fd.ops.mul(T3, T11) + S13 = fd.define_scalar(2.00000, dtype=DataType.Double) + S14 = fd.ops.reciprocal(S13) + T15 = fd.ops.mul(T10, S14) + fd.add_output(T10) + fd.add_output(T12) + fd.add_output(T15) + + inputs = [ + torch.randn((4,), dtype=torch.float32, device="cuda:0").as_strided( + (2, 2), (2, 1) + ), + torch.randn((4,), dtype=torch.float32, device="cuda:0").as_strided( + (2, 2), (2, 1) + ), + ] + + self.exec_nvfuser(fusion_func, inputs) From 3cd58b0fef4de135f1d51cbab3c5cc480bd3729c Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:40:10 -0700 Subject: [PATCH 13/54] Rename files in `tests/python` (#2805) PR #2701 to allow creating modular test files and moves to pytest from unittest. Given these changes, we are doing the following renaming/reorganization: - Merge `pytest_utils.py` and `utils.py` since we are only using `pytest` now. - `pytest_core.py` -> `opinfo_core.py` - `pytest_framework.py` -> `opinfo_framework.py ` - `pytest_fusion_definitions.py` -> `opinfo_fusion_definitions.py` - `pytest_input_generators.py` -> `opinfo_input_generators.py` - `pytest_opinfos.py` -> `opinfos.py` - `pytest_ops.py` -> `test_ops.py`. (Allows automatic test file discovery with `pytest`) - `pytest.md` -> `README.md` so it will displayed on github page. Next PRs: - ~Remove other instances of `unittest`.~ (PR #2809 merged) - Update README with instructions on creating new tests, using serde check decorators, and the `NVFuserTest` class. - ~Extract matmul/linear and SDPA tests from `test_python_frontend.py` into a separate test file each.~ (PR #2806 merged) --- manual_ci.sh | 4 +- tests/python/{Pytest.md => README.md} | 22 +-- .../python/{pytest_core.py => opinfo_core.py} | 6 +- ...ytest_framework.py => opinfo_framework.py} | 2 +- ...itions.py => opinfo_fusion_definitions.py} | 4 +- ...nerators.py => opinfo_input_generators.py} | 4 +- .../python/{pytest_opinfos.py => opinfos.py} | 10 +- tests/python/pytest_utils.py | 178 ------------------ tests/python/{pytest_ops.py => test_ops.py} | 10 +- tests/python/utils.py | 174 ++++++++++++++++- tools/codediff/compare_codegen.sh | 2 +- 11 files changed, 204 insertions(+), 212 deletions(-) rename tests/python/{Pytest.md => README.md} (52%) rename tests/python/{pytest_core.py => opinfo_core.py} (96%) rename tests/python/{pytest_framework.py => opinfo_framework.py} (98%) rename tests/python/{pytest_fusion_definitions.py => opinfo_fusion_definitions.py} (97%) rename tests/python/{pytest_input_generators.py => opinfo_input_generators.py} (99%) rename tests/python/{pytest_opinfos.py => opinfos.py} (99%) delete mode 100644 tests/python/pytest_utils.py rename tests/python/{pytest_ops.py => test_ops.py} (97%) diff --git a/manual_ci.sh b/manual_ci.sh index 092929f7e8b..26e2cc1ce34 100755 --- a/manual_ci.sh +++ b/manual_ci.sh @@ -29,8 +29,8 @@ run_test './bin/tutorial' run_test './bin/test_python_frontend' run_test './bin/test_profiler' -run_test 'pytest tests/python/pytest_ops.py' -run_test 'python tests/python/test_python_frontend.py' +run_test 'pytest tests/python/test_ops.py' +run_test 'pytest tests/python/test_python_frontend.py' run_test 'pytest tests/python/test_schedule_ops.py' if $failed_tests; diff --git a/tests/python/Pytest.md b/tests/python/README.md similarity index 52% rename from tests/python/Pytest.md rename to tests/python/README.md index f77e6ff36af..9caa4fe156f 100644 --- a/tests/python/Pytest.md +++ b/tests/python/README.md @@ -8,10 +8,10 @@ ## Usage -* Run tests: `pytest python_tests/pytest_ops.py` -* Filter tests with `-k` option: `pytest python_tests/pytest_ops.py -k var_mean` -* Show all possible tests: `pytest python_tests/pytest_ops.py --collect-only` -* Filter all possible tests with `-k` option: `pytest python_tests/pytest_ops.py --collect-only -k var_mean` +* Run tests: `pytest python_tests/test_ops.py` +* Filter tests with `-k` option: `pytest python_tests/test_ops.py -k var_mean` +* Show all possible tests: `pytest python_tests/test_ops.py --collect-only` +* Filter all possible tests with `-k` option: `pytest python_tests/test_ops.py --collect-only -k var_mean` ## Dependencies * `pytest` @@ -19,14 +19,14 @@ ## Code Organization ### Files modified When Adding a New Op -* `pytest_opinfos.py`: Each operation corresponds to an OpInfo object -* `pytest_input_generators.py`: A set of correctness and error input generators are needed to create test cases for each operation. -* `pytest_fusion_definitions.py` (Less Frequent): A specific operation might need a unique `FusionDefinition` function in order to test the new operation and that function would be added in this file. +* `opinfos.py`: Each operation corresponds to an OpInfo object +* `opinfo_input_generators.py`: A set of correctness and error input generators are needed to create test cases for each operation. +* `opinfo_fusion_definitions.py` (Less Frequent): A specific operation might need a unique `FusionDefinition` function in order to test the new operation and that function would be added in this file. ### Structural Code Used By All Tests -* `pytest_core.py`: Contains the defintion of the `Opinfo` object. -* `pytest_framework.py`: Contains the decorator template to iterate over all ops for a given test case. -* `pytest_ops.py`: Defines correctness and error tests for `FusionDefinition` `definition` operations. +* `opinfo_core.py`: Contains the defintion of the `Opinfo` object. +* `opinfo_framework.py`: Contains the decorator template to iterate over all ops for a given test case. +* `test_ops.py`: Defines correctness and error tests for `FusionDefinition` `definition` operations. ### Misc -* `pytest_utils.py`: Common helper functions +* `utils.py`: Common helper functions diff --git a/tests/python/pytest_core.py b/tests/python/opinfo_core.py similarity index 96% rename from tests/python/pytest_core.py rename to tests/python/opinfo_core.py index c4873005c41..940a45b89ed 100644 --- a/tests/python/pytest_core.py +++ b/tests/python/opinfo_core.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Owner(s): ["module: nvfuser"] -from pytest_utils import ( +from utils import ( all_dtypes_except_reduced, ArgumentType, torch_to_python_dtype_map, @@ -50,12 +50,12 @@ def __repr__(self): return f"[SampleInput args={self.args} kwargs={self.kwargs}]" def jax(self): - from pytest_utils import JAX_AVAILABLE + from utils import JAX_AVAILABLE assert JAX_AVAILABLE import jax.numpy as jnp - from pytest_utils import torch_to_jax_dtype_map + from utils import torch_to_jax_dtype_map def to_jax(t): if isinstance(t, torch.Tensor): diff --git a/tests/python/pytest_framework.py b/tests/python/opinfo_framework.py similarity index 98% rename from tests/python/pytest_framework.py rename to tests/python/opinfo_framework.py index 0c1fa39ca77..61b0c8ac3a2 100644 --- a/tests/python/pytest_framework.py +++ b/tests/python/opinfo_framework.py @@ -6,7 +6,7 @@ import inspect import torch from typing import Callable -from pytest_utils import map_dtype_to_str +from utils import map_dtype_to_str import pytest diff --git a/tests/python/pytest_fusion_definitions.py b/tests/python/opinfo_fusion_definitions.py similarity index 97% rename from tests/python/pytest_fusion_definitions.py rename to tests/python/opinfo_fusion_definitions.py index 5e4d164584c..95abad9b7f4 100644 --- a/tests/python/pytest_fusion_definitions.py +++ b/tests/python/opinfo_fusion_definitions.py @@ -5,8 +5,8 @@ import torch -from pytest_core import OpInfo -from pytest_utils import ArgumentType, is_tensor +from opinfo_core import OpInfo +from utils import ArgumentType, is_tensor from nvfuser import FusionDefinition from nvfuser.pytorch_utils import ( diff --git a/tests/python/pytest_input_generators.py b/tests/python/opinfo_input_generators.py similarity index 99% rename from tests/python/pytest_input_generators.py rename to tests/python/opinfo_input_generators.py index 1118156c3eb..779f84bfb3b 100644 --- a/tests/python/pytest_input_generators.py +++ b/tests/python/opinfo_input_generators.py @@ -12,8 +12,8 @@ import random from numbers import Number -from pytest_core import OpInfo, SampleInput, ErrorSample, Domain -from pytest_utils import ( +from opinfo_core import OpInfo, SampleInput, ErrorSample, Domain +from utils import ( make_number, find_nonmatching_dtype, is_floating_dtype, diff --git a/tests/python/pytest_opinfos.py b/tests/python/opinfos.py similarity index 99% rename from tests/python/pytest_opinfos.py rename to tests/python/opinfos.py index 5582f7ff406..e711a6ceae0 100644 --- a/tests/python/pytest_opinfos.py +++ b/tests/python/opinfos.py @@ -5,14 +5,14 @@ import math import torch -from pytest_core import OpInfo, ReferenceType, Domain -from pytest_fusion_definitions import ( +from opinfo_core import OpInfo, ReferenceType, Domain +from opinfo_fusion_definitions import ( api_test_fd_fn, tensor_input_fd_fn, tensor_api_test_fd_fn, vector_api_test_fd_fn, ) -from pytest_input_generators import ( +from opinfo_input_generators import ( broadcast_error_generator, broadcast_in_dim_generator, broadcast_in_dim_error_generator, @@ -51,7 +51,7 @@ linear_input_generator, linear_error_generator, ) -from pytest_utils import ( +from utils import ( bool_int_dtypes, complex_dtypes, full_precision_float_dtypes, @@ -62,7 +62,7 @@ ) from functools import partial -from pytest_utils import JAX_AVAILABLE +from utils import JAX_AVAILABLE if JAX_AVAILABLE: import jax diff --git a/tests/python/pytest_utils.py b/tests/python/pytest_utils.py deleted file mode 100644 index ff0eee99abe..00000000000 --- a/tests/python/pytest_utils.py +++ /dev/null @@ -1,178 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# Owner(s): ["module: nvfuser"] - -import pytest -import torch -from torch.testing import make_tensor -from typing import Optional -from functools import wraps -from enum import Enum, auto - -try: - # flake8: noqa - import jax - - JAX_AVAILABLE = True -except ImportError as e: - JAX_AVAILABLE = False - pass - - -def requiresJAX(fn): - @wraps(fn) - def _fn(*args, **kwargs): - if not JAX_AVAILABLE: - pytest.xfail("Requires JAX") - return fn(*args, **kwargs) - - return _fn - - -class ArgumentType(Enum): - # a symbolic value requires an input argument during kernel execution - Symbolic = auto() - # scalar with constant value - ConstantScalar = auto() - # python number - int, float, complex, bool - Constant = auto() - - -bool_dtypes = (torch.bool,) - -int_dtypes = ( - torch.int32, - torch.int64, -) - -half_precision_float_dtypes = ( - torch.bfloat16, - torch.float16, -) - -full_precision_float_dtypes = ( - torch.float32, - torch.float64, -) - -complex_dtypes = ( - torch.complex64, - torch.complex128, -) - -# Half-precision float dtypes bf16, fp16 are skipped because nvfuser upcasts those dtypes to fp32 -# but does not return the original type. -bool_int_dtypes = bool_dtypes + int_dtypes -float_dtypes = half_precision_float_dtypes + full_precision_float_dtypes -int_float_dtypes = int_dtypes + full_precision_float_dtypes -float_complex_dtypes = full_precision_float_dtypes + complex_dtypes -all_dtypes_except_reduced = int_dtypes + full_precision_float_dtypes + complex_dtypes -all_dtypes_except_bool = all_dtypes_except_reduced + half_precision_float_dtypes -all_dtypes = all_dtypes_except_bool + bool_dtypes - -map_dtype_to_str = { - torch.bool: "bool", - torch.uint8: "uint8", - torch.int8: "int8", - torch.int16: "int16", - torch.int32: "int32", - torch.int64: "int64", - torch.bfloat16: "bfloat16", - torch.float16: "float16", - torch.float32: "float32", - torch.float64: "float64", - torch.complex64: "complex64", - torch.complex128: "complex128", -} - -torch_to_jax_dtype_map = None -if JAX_AVAILABLE: - import jax.numpy as jnp - - torch_to_jax_dtype_map = { - torch.bool: jnp.bool_, - torch.uint8: jnp.uint8, - torch.int8: jnp.int8, - torch.int16: jnp.int16, - torch.int32: jnp.int32, - torch.int64: jnp.int64, - torch.bfloat16: jnp.bfloat16, - torch.float16: jnp.float16, - torch.float32: jnp.float32, - torch.float64: jnp.float64, - torch.complex64: jnp.complex64, - torch.complex128: jnp.complex128, - } - -torch_to_python_dtype_map = { - torch.bool: bool, - torch.uint8: int, - torch.int8: int, - torch.int16: int, - torch.int32: int, - torch.int64: int, - torch.bfloat16: float, - torch.float16: float, - torch.float32: float, - torch.float64: float, - torch.complex64: complex, - torch.complex128: complex, -} - - -def make_tensor_like(a): - # type: (torch.Tensor) -> torch.Tensor - """Returns a tensor with the same properties as the given tensor. - - Args: - a (torch.Tensor): The tensor to copy properties from. - - Returns: - torch.Tensor: A tensor with the same properties as :attr:`a`. - """ - return torch.testing.make_tensor( - a.shape, device=a.device, dtype=a.dtype, requires_grad=a.requires_grad - ) - - -def make_number( - dtype: torch.dtype, low: Optional[float] = None, high: Optional[float] = None -): - """Returns a random number with desired dtype - - Args: - dtype (torch.dtype): Desired dtype for number. - low (Optional[Number]): Sets the lower limit (inclusive) of the given range. - high (Optional[Number]): Sets the upper limit (exclusive) of the given range. - - Returns: - (Scalar): The scalar number with specified dtype. - """ - return make_tensor([1], device="cpu", dtype=dtype, low=low, high=high).item() - - -def find_nonmatching_dtype(dtype: torch.dtype): - if dtype in int_float_dtypes: - return torch.complex128 - elif dtype in complex_dtypes: - return torch.double - elif dtype is torch.bool: - return torch.float32 - return None - - -def is_complex_dtype(dtype: torch.dtype): - return dtype in complex_dtypes - - -def is_floating_dtype(dtype: torch.dtype): - return dtype in float_dtypes - - -def is_integer_dtype(dtype: torch.dtype): - return dtype in int_dtypes - - -def is_tensor(a): - return isinstance(a, torch.Tensor) diff --git a/tests/python/pytest_ops.py b/tests/python/test_ops.py similarity index 97% rename from tests/python/pytest_ops.py rename to tests/python/test_ops.py index 8de3316e28c..ded5e946bde 100644 --- a/tests/python/pytest_ops.py +++ b/tests/python/test_ops.py @@ -9,11 +9,11 @@ from copy import deepcopy from benchmarks.python.core import clear_cuda_cache -from pytest_fusion_definitions import default_fd_fn, parse_inputs_fusion_definition -from pytest_framework import create_op_test, atexit_serde_create_op_test -from pytest_core import ReferenceType, OpInfo, SampleInput -from pytest_opinfos import opinfos -from pytest_utils import ArgumentType, is_tensor, requiresJAX +from opinfo_fusion_definitions import default_fd_fn, parse_inputs_fusion_definition +from opinfo_framework import create_op_test, atexit_serde_create_op_test +from opinfo_core import ReferenceType, OpInfo, SampleInput +from opinfos import opinfos +from utils import ArgumentType, is_tensor, requiresJAX from typing import Callable from nvfuser import FusionCache, FusionDefinition diff --git a/tests/python/utils.py b/tests/python/utils.py index 9d130bb8e8d..1d9fbe6c5d4 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -5,16 +5,186 @@ import os from copy import deepcopy -from typing import Callable +from typing import Callable, Optional import tempfile import torch +import pytest +from torch.testing import make_tensor +from functools import wraps +from enum import Enum, auto +from torch.testing._internal.common_utils import TestCase # flake8 complains about DataType being unused in this file but it is necessary # to run captured fusion definition. # flake8: noqa from nvfuser import FusionCache, FusionDefinition, DataType -from torch.testing._internal.common_utils import TestCase +try: + # flake8: noqa + import jax + + JAX_AVAILABLE = True +except ImportError as e: + JAX_AVAILABLE = False + pass + + +def requiresJAX(fn): + @wraps(fn) + def _fn(*args, **kwargs): + if not JAX_AVAILABLE: + pytest.xfail("Requires JAX") + return fn(*args, **kwargs) + + return _fn + + +class ArgumentType(Enum): + # a symbolic value requires an input argument during kernel execution + Symbolic = auto() + # scalar with constant value + ConstantScalar = auto() + # python number - int, float, complex, bool + Constant = auto() + + +bool_dtypes = (torch.bool,) + +int_dtypes = ( + torch.int32, + torch.int64, +) + +half_precision_float_dtypes = ( + torch.bfloat16, + torch.float16, +) + +full_precision_float_dtypes = ( + torch.float32, + torch.float64, +) + +complex_dtypes = ( + torch.complex64, + torch.complex128, +) + +# Half-precision float dtypes bf16, fp16 are skipped because nvfuser upcasts those dtypes to fp32 +# but does not return the original type. +bool_int_dtypes = bool_dtypes + int_dtypes +float_dtypes = half_precision_float_dtypes + full_precision_float_dtypes +int_float_dtypes = int_dtypes + full_precision_float_dtypes +float_complex_dtypes = full_precision_float_dtypes + complex_dtypes +all_dtypes_except_reduced = int_dtypes + full_precision_float_dtypes + complex_dtypes +all_dtypes_except_bool = all_dtypes_except_reduced + half_precision_float_dtypes +all_dtypes = all_dtypes_except_bool + bool_dtypes + +map_dtype_to_str = { + torch.bool: "bool", + torch.uint8: "uint8", + torch.int8: "int8", + torch.int16: "int16", + torch.int32: "int32", + torch.int64: "int64", + torch.bfloat16: "bfloat16", + torch.float16: "float16", + torch.float32: "float32", + torch.float64: "float64", + torch.complex64: "complex64", + torch.complex128: "complex128", +} + +torch_to_jax_dtype_map = None +if JAX_AVAILABLE: + import jax.numpy as jnp + + torch_to_jax_dtype_map = { + torch.bool: jnp.bool_, + torch.uint8: jnp.uint8, + torch.int8: jnp.int8, + torch.int16: jnp.int16, + torch.int32: jnp.int32, + torch.int64: jnp.int64, + torch.bfloat16: jnp.bfloat16, + torch.float16: jnp.float16, + torch.float32: jnp.float32, + torch.float64: jnp.float64, + torch.complex64: jnp.complex64, + torch.complex128: jnp.complex128, + } + +torch_to_python_dtype_map = { + torch.bool: bool, + torch.uint8: int, + torch.int8: int, + torch.int16: int, + torch.int32: int, + torch.int64: int, + torch.bfloat16: float, + torch.float16: float, + torch.float32: float, + torch.float64: float, + torch.complex64: complex, + torch.complex128: complex, +} + + +def make_tensor_like(a): + # type: (torch.Tensor) -> torch.Tensor + """Returns a tensor with the same properties as the given tensor. + + Args: + a (torch.Tensor): The tensor to copy properties from. + + Returns: + torch.Tensor: A tensor with the same properties as :attr:`a`. + """ + return torch.testing.make_tensor( + a.shape, device=a.device, dtype=a.dtype, requires_grad=a.requires_grad + ) + + +def make_number( + dtype: torch.dtype, low: Optional[float] = None, high: Optional[float] = None +): + """Returns a random number with desired dtype + + Args: + dtype (torch.dtype): Desired dtype for number. + low (Optional[Number]): Sets the lower limit (inclusive) of the given range. + high (Optional[Number]): Sets the upper limit (exclusive) of the given range. + + Returns: + (Scalar): The scalar number with specified dtype. + """ + return make_tensor([1], device="cpu", dtype=dtype, low=low, high=high).item() + + +def find_nonmatching_dtype(dtype: torch.dtype): + if dtype in int_float_dtypes: + return torch.complex128 + elif dtype in complex_dtypes: + return torch.double + elif dtype is torch.bool: + return torch.float32 + return None + + +def is_complex_dtype(dtype: torch.dtype): + return dtype in complex_dtypes + + +def is_floating_dtype(dtype: torch.dtype): + return dtype in float_dtypes + + +def is_integer_dtype(dtype: torch.dtype): + return dtype in int_dtypes + + +def is_tensor(a): + return isinstance(a, torch.Tensor) def is_pre_volta(): diff --git a/tools/codediff/compare_codegen.sh b/tools/codediff/compare_codegen.sh index c7f7d701db4..f88b9c07ec5 100755 --- a/tools/codediff/compare_codegen.sh +++ b/tools/codediff/compare_codegen.sh @@ -214,7 +214,7 @@ collect_kernels() { # python tests # Using -s to disable capturing stdout. This is important as it will let us see which tests creates each .cu file "${bashcmd[@]}" -o "$pyopsdir" -- \ - python -m pytest "$nvfuserdir/python_tests/pytest_ops.py" -n 0 -v -s --color=yes + python -m pytest "$nvfuserdir/python_tests/test_ops.py" -n 0 -v -s --color=yes "${bashcmd[@]}" -o "$pyschedopsdir" -- \ python -m pytest "$nvfuserdir/python_tests/test_schedule_ops.py" -n 0 -v -s --color=yes "${bashcmd[@]}" -o "$pyfrontenddir" -- \ From a70d7b5df17aa10463ef2be7c5b958d2ec1f5a11 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 26 Aug 2024 14:43:56 -0700 Subject: [PATCH 14/54] Fix a bug in getReproString. (#2848) --- nvfuser/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index 5d1aa6ca725..eaa3a76cc97 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -67,7 +67,7 @@ def __exit__(self, type, value, traceback): def getReproString(self, inputs: list | None = None) -> str: msg = "# CUDA devices:\n" for i in range(torch.cuda.device_count()): - msg += f"# {0}: {torch.cuda.get_device_name(i)}\n" + msg += f"# {i}: {torch.cuda.get_device_name(i)}\n" msg += ( f"# torch version: {torch.__version__}\n" f"# cuda version: {torch.version.cuda}\n" From 00a18aae8a77e9c8cbcc8c182d1c7cb09b851150 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 26 Aug 2024 15:42:29 -0700 Subject: [PATCH 15/54] Simplify a test. (#2840) --- tests/cpp/test_resize.cpp | 82 ++++++++------------------------------- 1 file changed, 17 insertions(+), 65 deletions(-) diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index f8b2a2ecb02..3404356db30 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -1906,86 +1906,38 @@ TEST_F(ResizeTest, FusionSliceForNanoGPT2) { } // C++ version of TestNvFuserFrontend.test_nanogpt_split_mha_linears -TEST_F(ResizeTest, FusionSliceForNanoGPT3) { +TEST_F(ResizeTest, SliceForNanoGPT3) { // To verify input caching condition in this test, disable aliasing as that // will skip compilation and no kernel will exist. preseg_passes::OptimizationPassGuard optimization_guard(false); - auto fusion_ptr = std::make_unique(); - auto& fusion = *fusion_ptr; - FusionGuard fg(fusion_ptr.get()); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); EnableOptionsGuard opt_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::MemoryPromotion); - std::vector input_shape{16, 128, 3072}; - - auto tv0 = makeSymbolicTensor(3); - - fusion.addInput(tv0); - - auto tv1 = slice( - tv0, - {{IrBuilder::create(0L), IrBuilder::create(16L)}, - {IrBuilder::create(0L), IrBuilder::create(128L)}, - {IrBuilder::create(0L), IrBuilder::create(1024L)}}); - auto tv2 = slice( - tv0, - {{IrBuilder::create(0L), IrBuilder::create(16L)}, - {IrBuilder::create(0L), IrBuilder::create(128L)}, - {IrBuilder::create(1024L), IrBuilder::create(2048L)}}); - auto tv3 = slice( - tv0, - {{IrBuilder::create(0L), IrBuilder::create(16L)}, - {IrBuilder::create(0L), IrBuilder::create(128L)}, - {IrBuilder::create(2048L), IrBuilder::create(3072L)}}); - - auto tv4 = reshape(tv1, {16, 128, 1024}, {16, 128, 16, 64}); - auto tv5 = reshape(tv2, {16, 128, 1024}, {16, 128, 16, 64}); - auto tv6 = reshape(tv3, {16, 128, 1024}, {16, 128, 16, 64}); + auto* in = makeSymbolicTensor(3); + fusion->addInput(in); - // TODO: add permute - fusion.addOutput(tv4); - fusion.addOutput(tv5); - fusion.addOutput(tv6); + std::vector slices = chunk(in, /*chunks=*/3, /*dim=*/-1); + for (auto* slice : slices) { + TensorView* out = reshape(slice, {16, 128, 1024}, {16, 128, 16, 64}); + // TODO: add permute + fusion->addOutput(out); + } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in_tensor = at::randn({16, 128, 3072}, options); - auto t0 = at::randn(input_shape, options); - std::vector aten_inputs({t0}); - - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + FusionExecutorCache executor_cache(std::move(fusion)); + auto out_tensors = executor_cache.runFusionWithInputs({in_tensor}); + testValidate( + executor_cache.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); auto runtime = executor_cache.getMostRecentKernelRuntime(); - NVF_CHECK(!runtime->isSegmented(), "Segmentation not expected"); - - auto kernel = runtime->executors().at(0).kernel(); - NVF_CHECK( - !kernel->summary().has_cooperative_grid_reduction, - "Grid sync should not be used as slicing input should avoid input caching"); - - auto at_t1 = t0.index( - {at::indexing::Slice(0, 16), - at::indexing::Slice(0, 128), - at::indexing::Slice(0, 1024)}); - auto at_t2 = t0.index( - {at::indexing::Slice(0, 16), - at::indexing::Slice(0, 128), - at::indexing::Slice(1024, 2048)}); - auto at_t3 = t0.index( - {at::indexing::Slice(0, 16), - at::indexing::Slice(0, 128), - at::indexing::Slice(2048, 3072)}); - - auto at_t4 = at_t1.reshape({16, 128, 16, 64}); - auto at_t5 = at_t2.reshape({16, 128, 16, 64}); - auto at_t6 = at_t3.reshape({16, 128, 16, 64}); - - NVF_CHECK(cg_outputs.at(0).equal(at_t4)); - NVF_CHECK(cg_outputs.at(1).equal(at_t5)); - NVF_CHECK(cg_outputs.at(2).equal(at_t6)); + EXPECT_FALSE(runtime->isSegmented()); } TEST_F(ResizeTest, ResizeReshapeAndSlice) { From b17ca1af1e4611ea5967c28444fadc00800ea1f4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Mon, 26 Aug 2024 19:45:39 -0400 Subject: [PATCH 16/54] Run test_host_ir and test_multidevice in manual_ci.sh (#2845) This small PR just updates `manual_ci.sh` to have more complete coverage for local testing. --- manual_ci.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/manual_ci.sh b/manual_ci.sh index 26e2cc1ce34..025960cc5bb 100755 --- a/manual_ci.sh +++ b/manual_ci.sh @@ -3,14 +3,14 @@ failed_tests=false run_test() { - eval "$1" + eval "$*" status=$? if [ $status -ne 0 ]; then failed_tests=true echo "=============================================================" echo "= test_failed!" - echo "= $1" + echo "= $*" echo "=============================================================" fi } @@ -21,7 +21,11 @@ run_test './bin/lib/dynamic_type/test_dynamic_type_17' run_test './bin/lib/dynamic_type/test_dynamic_type_20' run_test './bin/nvfuser_tests' run_test './bin/test_rng' -# run_test './bin/test_multidevice' +run_test './bin/test_host_ir' +if type -p mpirun > /dev/null +then + run_test mpirun -np 1 './bin/test_multidevice' +fi run_test './bin/test_view' run_test './bin/test_matmul' run_test './bin/test_external_src' From 08db8fa23845c6e310d81660770669268d8b3d73 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 26 Aug 2024 17:33:49 -0700 Subject: [PATCH 17/54] Adds a Graphviz graph printer for ValGraph (#2849) Also adds NVFUSER_DUMP=indexing_verbose option This is just a utility for debugging and experimentation. There's no functional change. --- csrc/id_model/indexing.cpp | 8 ++++ csrc/options.cpp | 1 + csrc/options.h | 1 + csrc/val_graph.cpp | 79 ++++++++++++++++++++++++++++++++++++++ csrc/val_graph.h | 2 + 5 files changed, 91 insertions(+) diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index db8c5ec2053..23f2def403e 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -753,6 +753,14 @@ ParallelType getParallelType(const ValGroup& loop_group) { TensorIndexer::TensorIndexer(IdModel& id_model) : id_model_(id_model) { buildLoopIndexMap(); + + if (isDebugDumpEnabled(DebugDumpOption::IndexingVerbose)) { + std::ofstream ofs("indexing_traversal_graph.dot", std::ofstream::trunc); + auto dot_string = + id_model_.idGraph(IdMappingMode::ALMOSTEXACT).toGraphvizDotGraph(); + ofs << dot_string; + ofs.close(); + } } void TensorIndexer::buildLoopIndexMap() { diff --git a/csrc/options.cpp b/csrc/options.cpp index c2c9a7d8a25..8618384ccf1 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -122,6 +122,7 @@ std::unordered_map> Options< {"global_zeroed_memory", DebugDumpOption::GlobalZeroedMemory}, {"host_ir", DebugDumpOption::HostIr}, {"index_type", DebugDumpOption::IndexType}, + {"indexing_verbose", DebugDumpOption::IndexingVerbose}, {"kernel_args", DebugDumpOption::KernelArgs}, {"kernel_ir", DebugDumpOption::KernelIr}, {"launch_param", DebugDumpOption::LaunchParam}, diff --git a/csrc/options.h b/csrc/options.h index 073530c0673..2d8a48a0ec0 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -82,6 +82,7 @@ enum class DebugDumpOption { Occupancy, // Dump occupancy IndexType, //! Print the index type of the launched kernel PredicateElimination, //! Print the predicate elimination information + IndexingVerbose, //! Print verbose debug info on indexing EndOfOption //! Placeholder for counting the number of elements }; diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index 39e122fdd3e..845b4327f96 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -698,4 +698,83 @@ std::optional hasSelfMapping( return std::nullopt; } +std::string ValGraph::toGraphvizDotGraph() const { + std::stringstream dot; + + dot << "digraph ValGraph {\n"; + + const std::string indent = " "; + + // Use the pointer value as the name and attach a label with the + // val names + std::unordered_map val_names; + for (const auto& val_group : disjointValSets().disjointSets()) { + std::stringstream name; + name << "val_" << val_group.get(); + val_names.emplace(val_group, name.str()); + } + + std::unordered_map expr_names; + for (const auto& group : disjointExprSets().disjointSets()) { + std::stringstream name; + name << "expr_" << group.get(); + expr_names.emplace(group, name.str()); + } + + auto getGroupLabel = [](const auto& group) -> std::string { + std::set names; + for (const auto val : *group) { + names.insert(val->name()); + } + std::stringstream ss; + const int line_limit = 5; + int wrap_counter = 0; + bool first_name = true; + for (const auto& name : names) { + if (wrap_counter == line_limit) { + ss << "\n"; + wrap_counter = 0; + } else if (!first_name) { + ss << " "; + } + ss << name; + first_name = false; + ++wrap_counter; + } + return ss.str(); + }; + + for (const auto& val_group : disjointValSets().disjointSets()) { + dot << indent << val_names.at(val_group) + << " [label=\"V: " << getGroupLabel(val_group) << "\"];\n"; + } + + for (const auto& expr_group : disjointExprSets().disjointSets()) { + dot << indent << expr_names.at(expr_group) + << " [label=\"E: " << getGroupLabel(expr_group) << "\"];\n"; + } + + for (const auto& val_group : disjointValSets().disjointSets()) { + dot << indent << "// Definitions of " << nvfuser::toString(val_group) + << "\n"; + for (const auto& def : getDefinitions(val_group)) { + dot << indent << expr_names.at(def) << " -> " << val_names.at(val_group) + << "\n"; + } + + dot << indent << "// Uses of " << nvfuser::toString(val_group) << "\n"; + + for (const auto& use : getUses(val_group)) { + dot << indent << val_names.at(val_group) << " -> " << expr_names.at(use) + << "\n"; + } + + dot << "\n"; + } + + dot << "}\n"; + + return dot.str(); +} + } // namespace nvfuser diff --git a/csrc/val_graph.h b/csrc/val_graph.h index e9d9f0d5fd1..d8aa9219bb6 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -185,6 +185,8 @@ class ValGraph { std::string toString() const; + std::string toGraphvizDotGraph() const; + // Initializes entries for the provided Val with its definitions and // uses. The provided Val will have its own new ValGroup, each item in the // definitions and uses will become a new ExprGroup, and these new ExprGroups From de1672c653094062d46125ed622d2b781a50aa77 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Mon, 26 Aug 2024 19:57:32 -0700 Subject: [PATCH 18/54] Add support for `torch.minimum` and `torch.maximum` operations (#2847) This PR adds the `minimum` and `maximum` operators to nvfuser. - https://pytorch.org/docs/stable/generated/torch.minimum.html - https://pytorch.org/docs/stable/generated/torch.maximum.html --- csrc/ops/arith.cpp | 2 ++ csrc/ops/arith.h | 10 ++++++++++ csrc/python_frontend/python_bindings.cpp | 2 ++ csrc/serde/fusion_record.cpp | 2 ++ tests/python/opinfos.py | 18 ++++++++++++++++++ 5 files changed, 34 insertions(+) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 0101057a8fb..399b128672f 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -945,6 +945,8 @@ NVFUSER_DEFINE_BINARY_CAST_OP(mul, Mul) NVFUSER_DEFINE_BINARY_CAST_OP(pow, Pow) NVFUSER_DEFINE_BINARY_CAST_OP(remainder, Remainder) NVFUSER_DEFINE_BINARY_CAST_OP(sub, Sub) +NVFUSER_DEFINE_BINARY_CAST_OP(minimum, Min) +NVFUSER_DEFINE_BINARY_CAST_OP(maximum, Max) #undef NVFUSER_DEFINE_BINARY_CAST_OP #define NVFUSER_DEFINE_LOGICAL_OP(op_name, op_type) \ diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 918a4372795..b84d9c03c5a 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -489,6 +489,16 @@ NVF_API Val* sub(Val* v1, Val* v2); NVF_API TensorView* sub(TensorView* v1, Val* v2); NVF_API TensorView* sub(Val* v1, TensorView* v2); NVF_API TensorView* sub(TensorView* v1, TensorView* v2); +// maximum +NVF_API Val* maximum(Val* v1, Val* v2); +NVF_API TensorView* maximum(TensorView* v1, Val* v2); +NVF_API TensorView* maximum(Val* v1, TensorView* v2); +NVF_API TensorView* maximum(TensorView* v1, TensorView* v2); +// minimum +NVF_API Val* minimum(Val* v1, Val* v2); +NVF_API TensorView* minimum(TensorView* v1, Val* v2); +NVF_API TensorView* minimum(Val* v1, TensorView* v2); +NVF_API TensorView* minimum(TensorView* v1, TensorView* v2); // nextafter: Only single- or double-precision // floating point types (after promotion) are supported. NVF_API Val* nextafter(Val* v1, Val* v2); diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index eda75ec83c3..79e7047a54f 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -1450,6 +1450,8 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_BINARY_OP("pow", pow) NVFUSER_PYTHON_BINDING_BINARY_OP("remainder", remainder) NVFUSER_PYTHON_BINDING_BINARY_OP("sub", sub) + NVFUSER_PYTHON_BINDING_BINARY_OP("minimum", minimum) + NVFUSER_PYTHON_BINDING_BINARY_OP("maximum", maximum) NVFUSER_PYTHON_BINDING_BINARY_OP("mod", mod) NVFUSER_PYTHON_BINDING_BINARY_OP("eq", eq) NVFUSER_PYTHON_BINDING_BINARY_OP("ge", ge) diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index ab51df611fc..f55cdedbb26 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -800,6 +800,8 @@ void RecordFunctorFactory::setupFunctionMaps() { NVFUSER_BINARY_TV_OP("pow", pow) NVFUSER_BINARY_TV_OP("remainder", remainder) NVFUSER_BINARY_TV_OP("sub", sub) + NVFUSER_BINARY_TV_OP("minimum", minimum) + NVFUSER_BINARY_TV_OP("maximum", maximum) NVFUSER_BINARY_TV_OP("mod", mod) NVFUSER_BINARY_TV_OP("eq", eq) NVFUSER_BINARY_TV_OP("ge", ge) diff --git a/tests/python/opinfos.py b/tests/python/opinfos.py index e711a6ceae0..e2d42edf7d7 100644 --- a/tests/python/opinfos.py +++ b/tests/python/opinfos.py @@ -661,6 +661,24 @@ ) binary_ops.append(lt_opinfo) +minimum_opinfo = OpInfo( + lambda fd: fd.ops.minimum, + "minimum", + dtypes=int_float_dtypes, + sample_input_generator=elementwise_binary_generator, + reference=_elementwise_binary_torch(torch.minimum), +) +binary_ops.append(minimum_opinfo) + +maximum_opinfo = OpInfo( + lambda fd: fd.ops.maximum, + "maximum", + dtypes=int_float_dtypes, + sample_input_generator=elementwise_binary_generator, + reference=_elementwise_binary_torch(torch.maximum), +) +binary_ops.append(maximum_opinfo) + mod_opinfo = OpInfo( lambda fd: fd.ops.mod, "mod", From f430969812173b997efaf2b60ce630419504aac9 Mon Sep 17 00:00:00 2001 From: Protonu Date: Tue, 27 Aug 2024 10:57:16 -0400 Subject: [PATCH 19/54] [stmatrix] Store a 8x8 matrix with hardcoded indices (#2822) This is an initial implementation of [stmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=stmatrix#warp-level-matrix-store-instruction-stmatrix), where we test it on a 8x8 matrix. We do not schedule the consumer of stmatrix as yet - instead we hardcode indices. The generated code is [here.](https://gist.github.com/protonu/7599c6ef4a3f41f9ab5de8332788fe41) --- csrc/codegen.cpp | 1 + .../analysis/sync_information.cpp | 1 + csrc/device_lower/pass/index.cpp | 37 +++++++++++--- csrc/device_lower/pass/inline_ptx.cpp | 14 ++++++ csrc/device_lower/utils.cpp | 7 +++ csrc/device_lower/utils.h | 2 + csrc/type.cpp | 2 + csrc/type.h | 3 +- tests/cpp/test_memory.cpp | 48 +++++++++++++++++++ 9 files changed, 108 insertions(+), 7 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index b99c6d97255..5a23c47dc34 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1246,6 +1246,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { auto optype = ldst->opType(); NVF_ERROR( optype != LoadStoreOpType::LdMatrix && + optype != LoadStoreOpType::StMatrix && optype != LoadStoreOpType::CpAsync, "ldmatrix and cp.async should be lowered as kir::Asm"); diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index 1d132619b1b..e1a3f5324b1 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -712,6 +712,7 @@ SyncMap::SyncMap(Fusion* fusion) { } else if (raw_dims.hasTID()) { NVF_ERROR( ir_utils::isLdMatrixOp(producer->definition()) || + ir_utils::isStMatrixOp(consumer->definition()) || producer->getMemoryType() == MemoryType::Global || producer->getMemoryType() == MemoryType::Shared, "Inconsistent parallelization found between TV", diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 19b8aea782a..da31de9020d 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1507,16 +1507,41 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { std::make_shared(DataType::UInt32), (size_t)ir_utils::getVectorizeSize(ldst->out()->as()) / 2}; + } else if (ir_utils::isStMatrixOp(ldst)) { + as_type = ArrayType{ + std::make_shared(DataType::UInt32), + 1 /*hard coded for 8*8 store*/}; } else if (ldst->out()->definition()->isA()) { // For MMA accumulator initialization as_type = getMmaOutType(ldst->out()->as()); } - in = lowerSrcIndex( - ldst->in(), - ldst->out(), - {}, - ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst)); - out = lowerDstIndex(ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type); + + if (ir_utils::isStMatrixOp(ldst)) { + // Currently we create hard coded indexing for stmatrix which works on 8x8 + // matrices. T_local[0] + in = IrBuilder::create( + dynamic_cast(ldst->in()), + IrBuilder::create(0, DataType::Index), + as_type); + + // T_shared[toSmem(T_shared) + 16 * tidx.x] + auto out_index = IrBuilder::addExpr( + IrBuilder::baseAddressExpr(dynamic_cast(ldst->out())), + IrBuilder::mulExpr( + IrBuilder::create(16, DataType::Index), + IrBuilder::create("threadIdx.x", DataType::Index))); + + out = IrBuilder::create( + dynamic_cast(ldst->out()), out_index); + } else { + in = lowerSrcIndex( + ldst->in(), + ldst->out(), + {}, + ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst)); + out = + lowerDstIndex(ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type); + } auto new_ldst = IrBuilder::create(ldst->opType(), out, in, ldst->cacheOp()) ->withPredicate(ldst->predicate()); diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp index 4b7c0a67a4e..606a6d845e0 100644 --- a/csrc/device_lower/pass/inline_ptx.cpp +++ b/csrc/device_lower/pass/inline_ptx.cpp @@ -72,6 +72,20 @@ class LowerToInlinePtx : public kir::ExprMutator { std::vector{ldst->in()}, kir::Asm::Options{/*volatile=*/true})); return; + } else if (ir_utils::isStMatrixOp(ldst)) { + std::stringstream ss; + ss << "stmatrix.sync.aligned.x" + << std::get(ldst->in()->dtype().type).size; + ss << ".m8n8.shared.b16"; + registerReplace( + ldst, + // stmatrix has no output. + IrBuilder::create( + ss.str(), + std::vector{}, + std::vector{ldst->out(), ldst->in()}, + kir::Asm::Options{/*volatile=*/true})); + return; } else if (ir_utils::isCpAsyncOp(ldst)) { auto out_tv = ldst->out()->as()->view(); auto vec_size = diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index a75f95c4ab0..667fb7c66f1 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -180,6 +180,13 @@ bool isLdMatrixOp(const Expr* expr) { return false; } +bool isStMatrixOp(const Expr* expr) { + if (auto ldst = dynamic_cast(expr)) { + return ldst->opType() == LoadStoreOpType::StMatrix; + } + return false; +} + bool isCpAsyncOp(const Expr* expr) { if (auto ldst = dynamic_cast(expr)) { return ldst->opType() == LoadStoreOpType::CpAsync; diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index fa192dca180..061786aa1e3 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -128,6 +128,8 @@ std::unordered_map getParallelDomains( //! a ldmatrix intrinsic. bool isLdMatrixOp(const Expr* expr); +bool isStMatrixOp(const Expr* expr); + //! Returns true if the expression will be lowered to //! a cp.async intrinsic. bool isCpAsyncOp(const Expr* expr); diff --git a/csrc/type.cpp b/csrc/type.cpp index b022f7ee087..7d92ad05dab 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -836,6 +836,8 @@ const char* load_store_type2string(LoadStoreOpType t) { return "Set"; case LoadStoreOpType::LdMatrix: return "LdMatrix"; + case LoadStoreOpType::StMatrix: + return "StMatrix"; case LoadStoreOpType::CpAsync: return "CpAsync"; case LoadStoreOpType::CpAsyncBulkTensorTile: diff --git a/csrc/type.h b/csrc/type.h index f78ebed0531..d3adcd6092a 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -753,7 +753,8 @@ enum class LoadStoreOpType { SegmenterSet, LdMatrix, CpAsync, - CpAsyncBulkTensorTile + CpAsyncBulkTensorTile, + StMatrix }; // Used to label what part of the circular buffered iterdomain diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp index 5811e453b2a..388169d3be5 100644 --- a/tests/cpp/test_memory.cpp +++ b/tests/cpp/test_memory.cpp @@ -2504,6 +2504,54 @@ TEST_P(LdMatrixTest, Regular) { testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); } +class StMatrixTest : public NVFuserTest { + protected: + void SetUp() override { + if (cudaArchGuardShouldSkip(9, 0)) { + GTEST_SKIP() << "skipping tests on pre-Hopper GPUs"; + } + NVFuserTest::SetUp(); + } +}; + +TEST_F(StMatrixTest, Regular) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto operand = MmaOperand::B; + + int sizeM = 8; + int sizeN = 8; + + auto tv0 = makeContigConcreteTensor({sizeM, sizeN}, DataType::Half); + fusion.addInput(tv0); + // tv0 (global) -> tv1 (shared) + auto tv1 = set(tv0); + tv1->setMemoryType(MemoryType::Shared); + auto tv2 = set(tv1); + // tv1 (shared) -> tv2 (registers) + // tv2 (registers) -> tv3 (shared) + auto tv3 = set(tv2); + tv3->definition()->as()->setOpType(LoadStoreOpType::StMatrix); + tv3->setMemoryType(MemoryType::Shared); + // tv3 (shared) -> tv4(global) + auto tv4 = set(tv3); + fusion.addOutput(tv4); + + tv2->applyMmaSwizzle(operand); + tv2->setAllocationDomain(tv2->getLoopDomain(), true); + // We do not schedule tv3 as yet. + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({sizeM, sizeN}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}, LaunchParams(), matmul_cparams); + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__); +} + TEST_P(LdMatrixTest, Transpose) { Fusion fusion; FusionGuard fg(&fusion); From 742f24c9c00c8a25d631751e7c8a05a2fa675e55 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Tue, 27 Aug 2024 08:59:48 -0700 Subject: [PATCH 20/54] Refactor CircularBufferLoopCloner (#2823) This PR refactors `CircularBufferLoopCloner` to avoid clang-tidy issues in https://github.com/NVIDIA/Fuser/pull/2773. - Track cloned for loop instead of its Scope - Add virtual methods `processExpr` and `processForLoop` for `TmaCircularBufferLoopCloner` to override. Details: ``` Error (CLANGTIDY) [bugprone-parent-virtual-call,-warnings-as-errors] qualified name 'kir::IrVisitor::dispatch' refers to a member overridden in subclass; did you mean 'nvfuser::CircularBufferLoopCloner'? ``` --- csrc/device_lower/pass/circular_buffer.cpp | 39 ++++++++++++++-------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index a0d38674095..d4cc53bbd71 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -118,8 +118,8 @@ class CircularBufferLoopCloner : public kir::IrVisitor { start, stop, /*step=*/GpuLower::current()->kernel()->oneVal(), - /*step=*/false, - /*vectorize=*/nullptr, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, circular_buffer_loop_->isUnrollRequired(), loop_type_); @@ -131,16 +131,24 @@ class CircularBufferLoopCloner : public kir::IrVisitor { ? cloned_top_level_loop_ : IrBuilder::create(fl); - cloned_scopes_.push_back(&cloned_loop->body()); + // Add to stack + for_loop_stack_.push_back(cloned_loop); + // Process for-loop kir::IrVisitor::handle(fl); - cloned_scopes_.pop_back(); + // Pop from stack + for_loop_stack_.pop_back(); + // Specific handling of for-loop + processForLoop(cloned_loop); + } + + virtual void processForLoop(ForLoop* cloned_loop) { // Add the cloned loop into the parent loop body only when the // cloned loop contains expressions. - if (!cloned_loop->body().empty() && !cloned_scopes_.empty()) { - cloned_scopes_.back()->push_back(cloned_loop); + if (!cloned_loop->body().empty() && !for_loop_stack_.empty()) { + for_loop_stack_.back()->body().push_back(cloned_loop); } } @@ -149,7 +157,7 @@ class CircularBufferLoopCloner : public kir::IrVisitor { } void dispatch(Expr* expr) override { - // skip expression if it is in exclude set + // Skip expression if it is in exclude set if (exclude_.count(expr) > 0) { return; } @@ -160,8 +168,13 @@ class CircularBufferLoopCloner : public kir::IrVisitor { return; } - NVF_ERROR(!cloned_scopes_.empty()); + NVF_ERROR(!for_loop_stack_.empty()); + + // Specific expression handling + processExpr(expr); + } + virtual void processExpr(Expr* expr) { switch (loop_type_) { case CircularBufferLoopStage::Prolog: { // In Prologue, only copy the load expressions. @@ -169,19 +182,19 @@ class CircularBufferLoopCloner : public kir::IrVisitor { // circular buffered TVs (e.g., buffer initialization). TensorView* out_tv = ir_utils::getTvOutput(expr); if (circular_buffer_load_tvs_.count(out_tv) > 0) { - cloned_scopes_.back()->push_back(expr); + for_loop_stack_.back()->body().push_back(expr); } break; } case CircularBufferLoopStage::Main: { - cloned_scopes_.back()->push_back(expr); + for_loop_stack_.back()->body().push_back(expr); break; } case CircularBufferLoopStage::Epilog: { // In Epilogue, copy everything except circular buffer load expressions. TensorView* out_tv = ir_utils::getTvOutput(expr); if (circular_buffer_load_tvs_.count(out_tv) == 0) { - cloned_scopes_.back()->push_back(expr); + for_loop_stack_.back()->body().push_back(expr); } break; } @@ -191,14 +204,14 @@ class CircularBufferLoopCloner : public kir::IrVisitor { } } - private: + protected: ForLoop* circular_buffer_loop_ = nullptr; const std::vector& circular_buffer_load_exprs_; const CircularBufferLoopStage loop_type_; std::unordered_set circular_buffer_load_tvs_; ForLoop* cloned_top_level_loop_ = nullptr; - std::deque cloned_scopes_; + std::vector for_loop_stack_; const std::unordered_set& exclude_; }; From cf141ed738b77b08e7c3bfe5e8b1311873c725e5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 27 Aug 2024 09:50:31 -0700 Subject: [PATCH 21/54] Clean up PadOp index lowering (#2852) A small cleanup of `IndexLowering::handle(PadOp*)`. Extended the pad unit tests to run both the legacy and new indexers. --- csrc/device_lower/pass/index.cpp | 82 +++--------- csrc/device_lower/pass/index.h | 7 - csrc/id_model/indexing.cpp | 3 - tests/cpp/test_resize.cpp | 216 +++++++++++++++++++++++++++---- 4 files changed, 204 insertions(+), 104 deletions(-) diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index da31de9020d..a833bfafe69 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1855,50 +1855,10 @@ void IndexLowering::allocateUniqueFusedReduction( insertAtTopLevel(fused_reduction_alloc_reduction); } -// This is mostly copied from Index::getProducerPerDimLogicalIndex() -Val* IndexLowering::getIterationIndexForBroadcast( - TensorView* producer_tv, - TensorView* consumer_tv, - IterDomain* broadcast_id) const { - NVF_ERROR( - broadcast_id->isBroadcast(), - "Expected broadcast ID but found ", - broadcast_id->toString()); - - auto c2p_logical_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv) - .mapBroadcast(false) - .mapConsumerToProducer(); - - // This replay has to be consistent with compute at index map. - BestEffortReplay replay_producer_as_consumer( - producer_tv->getLoopDomain(), - consumer_tv->getLoopDomain(), - c2p_logical_map); - - const auto& c2p_map = replay_producer_as_consumer.getReplay(); - const auto& producer_indexing_from_idgraph = getTensorIndexFromIdGraph( - for_loops_, getRotatedLoop(), consumer_tv, producer_tv, true, c2p_map); - - const auto& producer_indexing = producer_indexing_from_idgraph.index; - - const auto& index_map = producer_indexing.indexMap(); - const auto index_it = index_map.find(broadcast_id); - NVF_ERROR( - index_it != index_map.end(), - "Could not find padded consumer IterDomain ", - broadcast_id->toString(), - " from consumer TensorView ", - consumer_tv->toString(), - " in index map for producer TensorView ", - producer_tv->toString()); - - return index_it->second; -} - void IndexLowering::handle(const PadOp* pad) { // Convert to a where op as: - // consumer[consumer_idx] = (producer_idx >= 0 && producer_idx < - // producer_extent) ? + // consumer[consumer_idx] = (consumer_idx >= left_pad && consumer_idx < + // consumer_extent - right_pad) ? // producer[producer_idx] : // 0; @@ -1912,37 +1872,27 @@ void IndexLowering::handle(const PadOp* pad) { const auto pad_val = pad->value(); - std::unordered_map override_index; - for (auto padded_axis : pad->getPaddedAxes()) { - auto padded_id = producer_doms.at(padded_axis); - if (padded_id->isBroadcast()) { - // When we pad a Broadcast IterDomain, we should not treat it as a - // Broadcast as we normally would. Instead, we will treat it as a regular - // Iteration domain with extent 1. - auto ind = - getIterationIndexForBroadcast(producer_tv, consumer_tv, padded_id); - override_index.emplace(padded_id, ind); - } - } - - const auto producer_root_indices = Index::getProducerPerDimLogicalIndex( - producer_tv, consumer_tv, for_loops_, getRotatedLoop(), override_index); - // Build a predicate for where - Val* pred = IrBuilder::create(true); + auto consumer_root_indices = Index::getConsumerPerDimLogicalIndex( + consumer_tv, for_loops_, getRotatedLoop()); + Val* pred = consumer_tv->fusion()->trueVal(); for (auto padded_axis : pad->getPaddedAxes()) { - auto producer_idx = producer_root_indices.at(padded_axis); - auto producer_root_id = producer_doms.at(padded_axis); - NVF_ERROR(!producer_root_id->maybePartial()); + auto consumer_idx = consumer_root_indices.at(padded_axis); + auto consumer_root_id = consumer_tv->getLogicalDomain().at(padded_axis); + NVF_ERROR(!consumer_root_id->maybePartial()); + const auto& pad_widths = pad->getPadWidths(padded_axis); pred = SimplifyingIrBuilder::logicalAndExpr( pred, - // idx >= 0 && idx < extent + // idx >= left_pad && idx < extent - right_pad SimplifyingIrBuilder::logicalAndExpr( - SimplifyingIrBuilder::geExpr( - producer_idx, GpuLower::current()->kernel()->zeroVal()), + SimplifyingIrBuilder::geExpr(consumer_idx, pad_widths.first), SimplifyingIrBuilder::ltExpr( - producer_idx, producer_root_id->getMaybeExpandedExtent()))); + consumer_idx, + SimplifyingIrBuilder::subExpr( + consumer_root_id->getMaybeExpandedExtent(), + pad_widths.second)))); } + pred = GpuLower::current()->commonScalarMap().hoistScalar(pred, for_loops_); pushBack(IrBuilder::create( diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index 448933d3835..a9e07bdfed1 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -162,13 +162,6 @@ class IndexLowering : private OptOutConstDispatch { // fused reduction. void allocateUniqueFusedReduction(Expr* expr, TensorView* out_tv); - //! Get index of producer_tv as if broadcast_id had Iteration type instead of - //! Broadcast - Val* getIterationIndexForBroadcast( - TensorView* producer_tv, - TensorView* consumer_tv, - IterDomain* broadcast_id) const; - private: std::vector lowered_exprs_; diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index 23f2def403e..2308e576ac9 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -1131,9 +1131,6 @@ std::vector TensorIndexer::getPredicates( /*is_start_predicate=*/false, /*unswitched_loop=*/unswitched_loop); - const std::vector non_divisible_split_predicates = - getNonDivisibleConsumerDomainsToPredicate(tv); - const CircularBufferLoopStage loop_stage = getCircularBufferLoopStage( tv, for_loops, id_model_.idGraph(IdMappingMode::LOOP)); diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 3404356db30..b74348f15d9 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -23,7 +23,7 @@ namespace nvfuser { -using ResizeTest = NVFuserTest; +using ResizeTest = NVFuserFixtureParamTest; using testing::Each; using testing::HasSubstr; @@ -32,8 +32,14 @@ using testing::Property; using testing::ThrowsMessage; using testing::UnorderedElementsAre; +INSTANTIATE_TEST_SUITE_P( + , + ResizeTest, + testing::Bool(), + testing::PrintToStringParamName()); + // Simple pad test -TEST_F(ResizeTest, Pad1) { +TEST_P(ResizeTest, Pad1) { Fusion fusion; FusionGuard fg(&fusion); @@ -50,6 +56,13 @@ TEST_F(ResizeTest, Pad1) { auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -60,7 +73,7 @@ TEST_F(ResizeTest, Pad1) { } // pad + split -TEST_F(ResizeTest, Pad2) { +TEST_P(ResizeTest, Pad2) { Fusion fusion; FusionGuard fg(&fusion); @@ -79,6 +92,13 @@ TEST_F(ResizeTest, Pad2) { auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -89,7 +109,7 @@ TEST_F(ResizeTest, Pad2) { } // pad, merge + split, inlineMost -TEST_F(ResizeTest, Pad3) { +TEST_P(ResizeTest, Pad3) { Fusion fusion; FusionGuard fg(&fusion); @@ -125,6 +145,13 @@ TEST_F(ResizeTest, Pad3) { auto t1 = at::randn(padded_shape, options); std::vector aten_inputs({t0, t1}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -133,7 +160,7 @@ TEST_F(ResizeTest, Pad3) { } // pad + parallelization -TEST_F(ResizeTest, Pad4) { +TEST_P(ResizeTest, Pad4) { Fusion fusion; FusionGuard fg(&fusion); @@ -152,6 +179,13 @@ TEST_F(ResizeTest, Pad4) { auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -162,7 +196,7 @@ TEST_F(ResizeTest, Pad4) { } // pad + parallelization + RAW sync -TEST_F(ResizeTest, Pad5) { +TEST_P(ResizeTest, Pad5) { Fusion fusion; FusionGuard fg(&fusion); @@ -200,6 +234,13 @@ TEST_F(ResizeTest, Pad5) { auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -210,7 +251,7 @@ TEST_F(ResizeTest, Pad5) { } // pad + merge + split parallelization -TEST_F(ResizeTest, Pad6) { +TEST_P(ResizeTest, Pad6) { Fusion fusion; FusionGuard fg(&fusion); @@ -244,6 +285,13 @@ TEST_F(ResizeTest, Pad6) { auto t1 = at::randn(padded_shape, options); std::vector aten_inputs({t0, t1}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -253,7 +301,7 @@ TEST_F(ResizeTest, Pad6) { // pad + unswitch. Having different extents in an unswitched loop nest // needs a special care (see UnrollPass::canOmitElseClause) -TEST_F(ResizeTest, Pad7) { +TEST_P(ResizeTest, Pad7) { Fusion fusion; FusionGuard fg(&fusion); @@ -288,6 +336,13 @@ TEST_F(ResizeTest, Pad7) { auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -343,7 +398,7 @@ TEST_F(ResizeTest, Pad8) { } #endif -TEST_F(ResizeTest, PadScheduler1) { +TEST_P(ResizeTest, PadScheduler1) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -360,6 +415,13 @@ TEST_F(ResizeTest, PadScheduler1) { auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); @@ -368,7 +430,7 @@ TEST_F(ResizeTest, PadScheduler1) { NVF_CHECK(ref.equal(cg_outputs[0])); } -TEST_F(ResizeTest, PadScheduler2) { +TEST_P(ResizeTest, PadScheduler2) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -392,6 +454,13 @@ TEST_F(ResizeTest, PadScheduler2) { auto t1 = at::randn(padded_shape, options); std::vector aten_inputs({t0, t1}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); @@ -439,7 +508,7 @@ TEST_F(ResizeTest, PadScheduler3) { // Two pad exprs, both using the same symbolic pad widths, segmented // into two kernels. Make sure the symbolic inputs are available to // both of the segmented kernels. -TEST_F(ResizeTest, PadScheduler4) { +TEST_P(ResizeTest, PadScheduler4) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -467,6 +536,13 @@ TEST_F(ResizeTest, PadScheduler4) { std::vector pad_extents{1, 1}; std::vector aten_inputs({t0, 1, 1}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); @@ -476,7 +552,7 @@ TEST_F(ResizeTest, PadScheduler4) { // Pad a broadcast // See https://github.com/NVIDIA/Fuser/issues/798 -TEST_F(ResizeTest, PadBroadcastInput) { +TEST_P(ResizeTest, PadBroadcastInput) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -500,6 +576,13 @@ TEST_F(ResizeTest, PadBroadcastInput) { auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); @@ -1146,7 +1229,7 @@ std::vector> slice_cases( {-13, -11}}); // Test slice with a variety of constant ranges -TEST_F(NVFuserTest, SliceConstantShmoo_CUDA) { +TEST_F(ResizeTest, SliceConstantShmoo) { for (auto [start, stop] : slice_cases) { Fusion fusion; FusionGuard fg(&fusion); @@ -1175,7 +1258,7 @@ TEST_F(NVFuserTest, SliceConstantShmoo_CUDA) { } // Test slice with a variety of non-constant input ranges -TEST_F(NVFuserTest, SliceInputShmoo_CUDA) { +TEST_F(ResizeTest, SliceInputShmoo) { Fusion fusion; FusionGuard fg(&fusion); @@ -1223,9 +1306,9 @@ TEST_F(NVFuserTest, SliceInputShmoo_CUDA) { } } -// Same as SliceInputShmoo_CUDA but use FusionExecutorCache, which +// Same as SliceInputShmoo but use FusionExecutorCache, which // might re-concretize when output sizes change -TEST_F(NVFuserTest, SliceInputShmooFusionExecutorCache_CUDA) { +TEST_F(ResizeTest, SliceInputShmooFusionExecutorCache) { auto fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -1321,7 +1404,7 @@ TEST_F(ResizeTest, SliceExtentSimplification) { << "Unexpected resize output extent: " << resize_extent->toInlineString(); } -TEST_F(ResizeTest, PadReduceScheduler1) { +TEST_P(ResizeTest, PadReduceScheduler1) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -1355,6 +1438,13 @@ TEST_F(ResizeTest, PadReduceScheduler1) { std::back_inserter(aten_inputs), [](auto pad_extent) { return pad_extent; }); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); @@ -1638,7 +1728,7 @@ TEST_F(ResizeTest, SoftmaxSliceScheduler2) { } // Same as Pad1 but pad by specified value -TEST_F(ResizeTest, PadWithValue) { +TEST_P(ResizeTest, PadWithValue) { Fusion fusion; FusionGuard fg(&fusion); @@ -1658,6 +1748,13 @@ TEST_F(ResizeTest, PadWithValue) { auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -1668,7 +1765,7 @@ TEST_F(ResizeTest, PadWithValue) { } // Same as Pad1 but pad by negative value to create an empty tensor -TEST_F(ResizeTest, PadToEmptyTensor) { +TEST_P(ResizeTest, PadToEmptyTensor) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1690,6 +1787,13 @@ TEST_F(ResizeTest, PadToEmptyTensor) { auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); @@ -1699,7 +1803,7 @@ TEST_F(ResizeTest, PadToEmptyTensor) { } // Test that padding Half tensor by Double does not promote output -TEST_F(ResizeTest, PadHalfWithDoubleValue) { +TEST_P(ResizeTest, PadHalfWithDoubleValue) { Fusion fusion; FusionGuard fg(&fusion); @@ -1719,6 +1823,13 @@ TEST_F(ResizeTest, PadHalfWithDoubleValue) { auto t0 = at::ones(shape, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -2281,7 +2392,7 @@ TEST_F(ResizeTest, SliceVectorization) { // Concretize a symbolic pad that results in a broadcast (static pads) // In this test, the sizes and pad widths are static, so there should be nothing // to concretize. -TEST_F(NVFuserTest, ResizePadToBroadcastStatic_CUDA) { +TEST_P(ResizeTest, ResizePadToBroadcastStatic) { std::vector t0_size = {2, 3, 2, 5, 6}; std::vector t1_size = {2, 4, 4, 3, 5}; // Note there are only 8 input scalars for 5D input. Implicit no-pad of dim 0 @@ -2330,6 +2441,13 @@ TEST_F(NVFuserTest, ResizePadToBroadcastStatic_CUDA) { auto t1 = at::randn(t1_size, options); std::vector aten_inputs({t0, t1}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); @@ -2348,7 +2466,7 @@ TEST_F(NVFuserTest, ResizePadToBroadcastStatic_CUDA) { } // Concretize a symbolic pad that results in a broadcast (dynamic pads) -TEST_F(NVFuserTest, ResizePadToBroadcastDynamic_CUDA) { +TEST_P(ResizeTest, ResizePadToBroadcastDynamic) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -2395,6 +2513,13 @@ TEST_F(NVFuserTest, ResizePadToBroadcastDynamic_CUDA) { }); aten_inputs.insert(aten_inputs.end(), pad_widths.begin(), pad_widths.end()); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); @@ -2415,7 +2540,7 @@ TEST_F(NVFuserTest, ResizePadToBroadcastDynamic_CUDA) { } // See https://github.com/NVIDIA/Fuser/issues/596 -TEST_F(NVFuserTest, ResizePadToBroadcastIssue596_CUDA) { +TEST_P(ResizeTest, ResizePadToBroadcastIssue596) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -2437,6 +2562,13 @@ TEST_F(NVFuserTest, ResizePadToBroadcastIssue596_CUDA) { auto t1 = at::randn({3}, options); std::vector aten_inputs({t0, t1}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); FusionKernelRuntime runtime(std::move(fusion), args); runtime.compileFusionParallel(args); @@ -2930,7 +3062,7 @@ TEST_F(ResizeTest, SliceAndReshapeRepro540Manual) { // Test concretizing a pad that follows a reshape. This requires the // ExpressionEvaluator used in concretization to propagate shapes properly // across symbolic reshapes in order to infer the size of the downstream pad. -TEST_F(ResizeTest, ReshapeToPad) { +TEST_P(ResizeTest, ReshapeToPad) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); @@ -2951,6 +3083,13 @@ TEST_F(ResizeTest, ReshapeToPad) { auto tv2 = pad(tv1, {fusion.zeroVal(), s0, fusion.zeroVal(), s1}); fusion.addOutput(tv2); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr)); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -3094,7 +3233,7 @@ TEST_F(ResizeTest, CatOfExpandedBroadcast) { // padded in the empty dim as well as the expanded dims. // This should match test_python_frontend.py::test_pad_expanded_empty // See https://github.com/NVIDIA/Fuser/issues/870 -TEST_F(ResizeTest, PadExpandedEmpty) { +TEST_P(ResizeTest, PadExpandedEmpty) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(&fusion); @@ -3127,6 +3266,13 @@ TEST_F(ResizeTest, PadExpandedEmpty) { auto t0 = at::randn({0}, options).as_strided({2, 0, 3}, {0, 0, 0}); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); @@ -3136,7 +3282,7 @@ TEST_F(ResizeTest, PadExpandedEmpty) { // Test that we can pad properly along broadcast dims // See https://github.com/NVIDIA/Fuser/issues/868 -TEST_F(ResizeTest, PadOfBroadcast) { +TEST_P(ResizeTest, PadOfBroadcast) { Fusion fusion; FusionGuard fg(&fusion); @@ -3153,6 +3299,13 @@ TEST_F(ResizeTest, PadOfBroadcast) { auto t0 = at::randn(shape0, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -3162,7 +3315,7 @@ TEST_F(ResizeTest, PadOfBroadcast) { // Test that we can cat along broadcast dims that have been expanded // See https://github.com/NVIDIA/Fuser/issues/868 -TEST_F(ResizeTest, PadOfExpandedBroadcast) { +TEST_P(ResizeTest, PadOfExpandedBroadcast) { Fusion fusion; FusionGuard fg(&fusion); @@ -3182,6 +3335,13 @@ TEST_F(ResizeTest, PadOfExpandedBroadcast) { auto t0 = at::randn(shape0, options); std::vector aten_inputs({t0}); + EnableOptionsGuard enable_options_guard; + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); @@ -3189,7 +3349,7 @@ TEST_F(ResizeTest, PadOfExpandedBroadcast) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, dynamicReshapeIssue1393) { +TEST_F(ResizeTest, DynamicReshapeIssue1393) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion* fusion = fusion_ptr.get(); FusionGuard fg(fusion); From 2edebb60130b6f17c5b328279b41d35c15a21fbd Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 27 Aug 2024 10:52:00 -0700 Subject: [PATCH 22/54] Clean test_pointwise.cpp. (#2855) --- tests/cpp/test_pointwise.cpp | 58 ++++++++++-------------------------- 1 file changed, 15 insertions(+), 43 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 14a02d7414f..356fe8568ca 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -23,17 +23,15 @@ using PointwiseTest = NVFuserTest; namespace { -size_t getVecSizeForPointwise(FusionExecutorCache& fec) { - auto most_recent_params = - fec.getMostRecentKernelRuntime()->getMostRecentExecutorLog().params; - const auto* params = dynamic_cast(most_recent_params.get()); - NVF_ERROR( - params != nullptr, - "`fec`'s contained fusion didn't trigger the pointwise scheduler."); - if (params->vectorize) { - return params->unroll_factor; +int64_t getVecSizeForPointwise(const FusionExecutorCache& fec) { + FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); + NVF_CHECK(!runtime->isSegmented()); + const PointwiseParams& params = + runtime->schedulerHeuristics()->heuristicsList().at(0)->pointwiseParams(); + if (!params.vectorize) { + return 1; } - return 1; + return params.unroll_factor; } bool hasVectorizationCache(TensorView* tv) { @@ -66,7 +64,6 @@ TEST_F(PointwiseTest, VectorizeStrideContiguity2D) { fusion->addOutput(tv1); FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); std::vector> size_and_vec{{17, 1}, {18, 2}, {32, 4}}; @@ -77,7 +74,7 @@ TEST_F(PointwiseTest, VectorizeStrideContiguity2D) { at::Tensor t0 = at::randn({1000000, size}, options).narrow(1, 0, 16); auto cg_outputs = fec.runFusionWithInputs({t0}); - EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); + EXPECT_EQ(getVecSizeForPointwise(fec), vec); testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); } @@ -95,7 +92,6 @@ TEST_F(PointwiseTest, VectorizeStrideContiguity3D) { fusion->addOutput(tv1); FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); std::vector> size_and_vec{{17, 1}, {10, 2}, {16, 4}}; @@ -106,7 +102,7 @@ TEST_F(PointwiseTest, VectorizeStrideContiguity3D) { at::Tensor t0 = at::randn({1000000, size, 3}, options).narrow(1, 0, 8); auto cg_outputs = fec.runFusionWithInputs({t0}); - EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); + EXPECT_EQ(getVecSizeForPointwise(fec), vec); testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); } @@ -126,7 +122,6 @@ TEST_F(PointwiseTest, VectorizeStrideContiguity5D) { fusion->addOutput(tv1); FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -142,7 +137,7 @@ TEST_F(PointwiseTest, VectorizeStrideContiguity5D) { .narrow(3, 0, 4); auto cg_outputs = fec.runFusionWithInputs({t0}); - EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); + EXPECT_EQ(getVecSizeForPointwise(fec), vec); testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); } @@ -165,7 +160,6 @@ TEST_F(PointwiseTest, VectorizeStrideMisalignedBase) { fusion->addOutput(tv1); FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -202,7 +196,7 @@ TEST_F(PointwiseTest, VectorizeStrideMisalignedBase) { at::Tensor flat = at::randn({alloc_size}, options); at::Tensor t0 = flat.as_strided(shape, stride, /*storage_offset=*/align); auto cg_outputs = fec.runFusionWithInputs({t0}); - EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); + EXPECT_EQ(getVecSizeForPointwise(fec), vec); testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); } } @@ -221,7 +215,6 @@ TEST_F(PointwiseTest, VectorizeStrideContiguitySelfOverlapping) { fusion->addOutput(tv1); FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -249,7 +242,7 @@ TEST_F(PointwiseTest, VectorizeStrideContiguitySelfOverlapping) { at::Tensor t0 = at::empty_strided(shape, stride, options); t0.random_(); auto cg_outputs = fec.runFusionWithInputs({t0}); - EXPECT_EQ(getVecSizeForPointwise(fec), (size_t)vec); + EXPECT_EQ(getVecSizeForPointwise(fec), vec); testValidate(fusion, cg_outputs, {t0}, __LINE__, __FILE__); } } @@ -270,7 +263,6 @@ TEST_F(PointwiseTest, VectorizeAllocationDomain) { fusion->addOutput(tv1); FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = @@ -425,7 +417,6 @@ TEST_F(PointwiseTest, Issue1567VectorizationFactorAnalysisCase2) { fusion->addOutput(tv3); FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1024, 1, 2}, options); @@ -466,7 +457,6 @@ TEST_F(PointwiseTest, VIssue1567ectorizationFactorAnalysisCase3) { fusion->addOutput(tv3); FusionExecutorCache fec(std::move(fusion_ptr)); - fec.profile(true); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1, 1024, 2}, options); @@ -598,16 +588,7 @@ TEST_F(PointwiseTest, VectorizeWithBroadcastAndReshape1) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - NVF_CHECK(!executor_cache.getMostRecentKernelRuntime()->isSegmented()); - auto heuristic_params = executor_cache.getMostRecentKernelRuntime() - ->schedulerHeuristics() - ->heuristicsList() - .at(0) - ->params(); - ASSERT_TRUE(heuristic_params->isA()); - auto pparams = heuristic_params->as(); - ASSERT_TRUE(pparams->vectorize) << "Failed to vectorize"; - ASSERT_EQ(pparams->unroll_factor, 4) << "Unexpected vectorize factor"; + EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); } // Repro of issue #657 @@ -654,16 +635,7 @@ TEST_F(PointwiseTest, VectorizeWithBroadcastAndReshape2) { FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - NVF_CHECK(!executor_cache.getMostRecentKernelRuntime()->isSegmented()); - auto heuristic_params = executor_cache.getMostRecentKernelRuntime() - ->schedulerHeuristics() - ->heuristicsList() - .at(0) - ->params(); - ASSERT_TRUE(heuristic_params->isA()); - auto pparams = heuristic_params->as(); - ASSERT_TRUE(pparams->vectorize) << "Failed to vectorize"; - ASSERT_EQ(pparams->unroll_factor, 4) << "Unexpected vectorize factor"; + EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); } } // namespace nvfuser From 77c300954ca523595af517e01220f7b58c00b3ac Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 27 Aug 2024 13:29:06 -0700 Subject: [PATCH 23/54] Fix the computation of input_discontig_strides_. (#2854) --- csrc/scheduler/registry.cpp | 10 +++++++--- csrc/tensor_view.cpp | 6 +++++- tests/cpp/test_pointwise.cpp | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 6c162ce05e0..9800576758d 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -81,14 +81,18 @@ SchedulerRuntimeInfo::SchedulerRuntimeInfo( // find and push discontiguous stride int64_t dtype_size = dataTypeSize(input_tv->dtype()); input_discontig_strides_[fusion_inp] = {}; - int64_t dims = (int64_t)alloc_strides.size(); + auto dims = static_cast(alloc_strides.size()); int64_t expected_stride = 1; for (int64_t dim = dims - 1; dim >= 0; dim--) { auto size = alloc_sizes.at(dim); - if (size <= 1) { + auto stride = alloc_strides.at(dim); + // Skip broadcast dimensions because they don't affect contiguity. + // Consider to change this to check IterDomain::isBroadcast instead: + // https://github.com/NVIDIA/Fuser/pull/2854#discussion_r1733205035 + if (size <= 1 || stride == 0) { continue; } - auto stride = alloc_strides.at(dim); + if (stride != expected_stride) { input_discontig_strides_[fusion_inp].push_back(stride * dtype_size); expected_stride = stride; diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index a83fe3d69b3..8a34db6b7aa 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -1485,7 +1485,11 @@ TensorViewBuilder& TensorViewBuilder::strideOrder( TensorViewBuilder& TensorViewBuilder::expanded(std::vector expanded) { NVF_CHECK(expanded_.empty(), "Attempting to reset expanded shape"); if (!expanded.empty()) { - NVF_CHECK(ndims_ == 0 || ndims_ == (int64_t)expanded.size()); + NVF_CHECK( + ndims_ == 0 || ndims_ == (int64_t)expanded.size(), + ndims_, + " vs ", + expanded.size()); ndims_ = (int64_t)expanded.size(); } expanded_ = std::move(expanded); diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 356fe8568ca..515b710d830 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -638,4 +638,38 @@ TEST_F(PointwiseTest, VectorizeWithBroadcastAndReshape2) { EXPECT_EQ(getVecSizeForPointwise(executor_cache), 4); } +TEST_F(PointwiseTest, VectorizeWithExpandedBroadcast) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + constexpr int64_t kTensorSize = 65536; + TensorView* in = TensorViewBuilder() + .dtype(DataType::Half) + .shape({2, kTensorSize}) + .expanded({true, false}) + .build(); + in->setAllocationDomain({in->axis(1), in->axis(0)}, true); + TensorView* out = add(in, in); + fusion->addInput(in); + fusion->addOutput(out); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto in_tensor = + at::randn({kTensorSize}, options).as_strided({2, kTensorSize}, {0, 1}); + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + auto hparams = fec.getMostRecentKernelRuntime() + ->schedulerHeuristics() + ->heuristicsList() + .at(0) + ->params(); + ASSERT_TRUE(hparams->isA()); + const auto& pparams = hparams->as(); + EXPECT_TRUE(pparams->vectorize); + EXPECT_GT(pparams->unroll_factor, 1); +} + } // namespace nvfuser From 692ee6e4e47a50dcd6b07556ccd0f993df74204e Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 27 Aug 2024 15:38:59 -0700 Subject: [PATCH 24/54] A minor cleanup on top of #2854 and #2855. (#2859) --- tests/cpp/test_pointwise.cpp | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 515b710d830..1ec6ce45cad 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -661,15 +661,7 @@ TEST_F(PointwiseTest, VectorizeWithExpandedBroadcast) { auto out_tensors = fec.runFusionWithInputs({in_tensor}); testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); - auto hparams = fec.getMostRecentKernelRuntime() - ->schedulerHeuristics() - ->heuristicsList() - .at(0) - ->params(); - ASSERT_TRUE(hparams->isA()); - const auto& pparams = hparams->as(); - EXPECT_TRUE(pparams->vectorize); - EXPECT_GT(pparams->unroll_factor, 1); + EXPECT_GT(getVecSizeForPointwise(fec), 1); } } // namespace nvfuser From 7fdcb3c140844e2888f2bfa08f9778581bf9d71b Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 27 Aug 2024 18:02:17 -0700 Subject: [PATCH 25/54] Remove unnecessary calls to `manual_seed`. (#2863) NVFuserTest already sets up the seed for determinism. --- tests/cpp/test_gpu3.cpp | 6 ------ tests/cpp/test_pointwise.cpp | 2 -- 2 files changed, 8 deletions(-) diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 61caf77ab1e..4cc1ac113f4 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -6245,7 +6245,6 @@ TEST_F(NVFuserTest, FusionAvoidRedundantWriteBroadcastedSoftmaxInput_CUDA) { fusion.addOutput(tv4); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); at::Tensor t0 = at::ones(shape0, options); at::Tensor t1 = at::ones(shape1, options); std::vector inputs = {t0, t1}; @@ -6301,7 +6300,6 @@ TEST_F(NVFuserTest, FusionAvoidRedundantWrite_CUDA) { fusion.addOutput(tv4); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); at::Tensor t0 = at::randn(shape0, options); at::Tensor t1 = at::randn(shape1, options); std::vector inputs = {t0, t1}; @@ -6391,7 +6389,6 @@ TEST_F(NVFuserTest, FusionAvoidRedundantWriteDifferentConcretizedDomains_CUDA) { fusion.addOutput(tv8); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); at::Tensor t0 = at::randn(shape0, options); at::Tensor t1 = at::randn(shape1, options); at::Tensor t2 = at::randn(shape2, options); @@ -6453,7 +6450,6 @@ TEST_F(NVFuserTest, FusionAvoidRedundantWriteNonOutput_CUDA) { } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); at::Tensor t0 = at::randn({32}, options); at::Tensor t1 = at::randn({32, 64}, options); std::vector inputs = {t0, t1}; @@ -6518,7 +6514,6 @@ TEST_F(NVFuserTest, FusionAvoidRedundantWriteNonNeighbor_CUDA) { } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); at::Tensor t0 = at::randn({8, 10, 12}, options); at::Tensor t1 = at::randn({8, 7, 10, 12, 9}, options); std::vector inputs = {t0, t1}; @@ -7717,7 +7712,6 @@ TEST_F(NVFuserTest, PredicateRNGOps) { FusionExecutor fe; fe.compileFusion(fusion, {t0}); - at::manual_seed(0); auto cg_outputs = fe.runFusion({t0}); } diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 1ec6ce45cad..c8bb0566e1c 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -580,7 +580,6 @@ TEST_F(PointwiseTest, VectorizeWithBroadcastAndReshape1) { fusion->addOutput(tv4); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); auto t0 = at::randn(shape1, options); auto t1 = at::randn(shape2, options); std::vector aten_inputs({t0, t1}); @@ -626,7 +625,6 @@ TEST_F(PointwiseTest, VectorizeWithBroadcastAndReshape2) { fusion->addOutput(tv7); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); auto t0 = at::randn(shape1, options); auto t1 = at::randn(shape1, options); auto t2 = at::randn(shape2, options); From 2eeef463ac615bdecfbef97e7d6430c2a210dea4 Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Wed, 28 Aug 2024 10:07:06 -0700 Subject: [PATCH 26/54] Annotate inputs and outputs of transformer tests (#2858) --- benchmarks/python/test_transformer.py | 88 ++++++++++++++++++++------- 1 file changed, 66 insertions(+), 22 deletions(-) diff --git a/benchmarks/python/test_transformer.py b/benchmarks/python/test_transformer.py index ee2b4b0288a..47b25925590 100644 --- a/benchmarks/python/test_transformer.py +++ b/benchmarks/python/test_transformer.py @@ -44,6 +44,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: + # x: input T0 = fd.define_tensor( shape=[1, -1, -1], contiguity=[None, True, True], @@ -51,6 +52,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[2, 1, 0], ) + # layer_norm0.weight T1 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -58,6 +60,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # layer_norm0.bias T2 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -65,6 +68,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MHA linear0.weight T3 = fd.define_tensor( shape=[-1, -1], contiguity=[True, True], @@ -72,6 +76,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # MHA linear0.bias T4 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -79,6 +84,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MHA linear1.weight T5 = fd.define_tensor( shape=[-1, -1], contiguity=[True, True], @@ -86,6 +92,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # MHA linear1.bias T6 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -93,8 +100,11 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MHA dropout.rng_offset S7 = fd.define_scalar(None, dtype=DataType.Int) + # MHA dropout.rng_seed S8 = fd.define_scalar(None, dtype=DataType.Int) + # layer_norm1.weight T9 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -102,6 +112,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # layer_norm1.bias T10 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -109,6 +120,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MLP linear0.weight T11 = fd.define_tensor( shape=[-1, -1], contiguity=[True, True], @@ -116,6 +128,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # MLP linear0.bias T12 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -123,6 +136,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MLP linear1.weight T13 = fd.define_tensor( shape=[-1, -1], contiguity=[True, True], @@ -130,6 +144,7 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # MLP linear1.bias T14 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -137,8 +152,11 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MLP dropout.rng_offset S15 = fd.define_scalar(None, dtype=DataType.Int) + # MLP dropout.rng_seed S16 = fd.define_scalar(None, dtype=DataType.Int) + T17 = fd.ops.cast(T0, dtype=DataType.Float) T18, T19 = fd.ops.var_mean(T17, dims=[2], correction=0, keepdim=False) S20 = fd.define_scalar(1, dtype=DataType.Int) @@ -322,15 +340,15 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None: T186 = fd.ops.mul(T184, S185) T187 = fd.ops.add(T113, T186) T188 = fd.ops.cast(T187, dtype=DataType.BFloat16) - fd.add_output(T19) - fd.add_output(T32) - fd.add_output(T87) - fd.add_output(T88) - fd.add_output(T89) - fd.add_output(T90) - fd.add_output(T115) - fd.add_output(T128) - fd.add_output(T188) + fd.add_output(T19) # layer_norm0.welford_out.avg + fd.add_output(T32) # layer_norm0.invstd + fd.add_output(T87) # MHA sdpa.output + fd.add_output(T88) # MHA sdpa.logsum_exp + fd.add_output(T89) # MHA sdpa.philox_seed + fd.add_output(T90) # MHA sdpa.philox_offset + fd.add_output(T115) # layer_norm1.welford_out.avg + fd.add_output(T128) # layer_norm1.invstd + fd.add_output(T188) # output def test_transformer_forward( @@ -392,6 +410,7 @@ def test_transformer_forward( def transformer_backward_fusion(fd: FusionDefinition) -> None: + # x: input T0 = fd.define_tensor( shape=[1, -1, -1], contiguity=[None, True, True], @@ -399,6 +418,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[2, 1, 0], ) + # layer_norm0.welford_out.avg T1 = fd.define_tensor( shape=[1, -1], contiguity=[None, True], @@ -406,6 +426,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # layer_norm0.invstd T2 = fd.define_tensor( shape=[1, -1, 1], contiguity=[None, True, None], @@ -413,6 +434,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[2, 1, 0], ) + # layer_norm0.weight T3 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -420,6 +442,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # layer_norm0.bias T4 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -427,6 +450,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MHA linear0.weight T5 = fd.define_tensor( shape=[-1, -1], contiguity=[True, True], @@ -434,6 +458,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # MHA linear0.bias T6 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -441,6 +466,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MHA sdpa.output T7 = fd.define_tensor( shape=[1, -1, -1, -1], contiguity=[None, True, True, True], @@ -448,6 +474,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[3, 1, 2, 0], ) + # MHA linear1.weight T8 = fd.define_tensor( shape=[-1, -1], contiguity=[True, True], @@ -455,6 +482,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # MHA linear1.bias T9 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -462,8 +490,11 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MHA dropout.rng_offset S10 = fd.define_scalar(None, dtype=DataType.Int) + # MHA dropout.rng_seed S11 = fd.define_scalar(None, dtype=DataType.Int) + # layer_norm1.welford_out.avg T12 = fd.define_tensor( shape=[1, -1], contiguity=[None, True], @@ -471,6 +502,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # layer_norm1.invstd T13 = fd.define_tensor( shape=[1, -1, 1], contiguity=[None, True, None], @@ -478,6 +510,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[2, 1, 0], ) + # layer_norm1.weight T14 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -485,6 +518,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # layer_norm1.bias T15 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -492,6 +526,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MLP linear0.weight T16 = fd.define_tensor( shape=[-1, -1], contiguity=[True, True], @@ -499,6 +534,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # MLP linear0.bias T17 = fd.define_tensor( shape=[-1], contiguity=[True], @@ -506,8 +542,11 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[0], ) + # MLP dropout.rng_offset S18 = fd.define_scalar(None, dtype=DataType.Int) + # MLP dropout.rng_seed S19 = fd.define_scalar(None, dtype=DataType.Int) + # dy: incoming grad T20 = fd.define_tensor( shape=[1, -1, -1], contiguity=[None, True, True], @@ -515,6 +554,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[2, 1, 0], ) + # MLP linear1.weight T21 = fd.define_tensor( shape=[-1, -1], contiguity=[True, True], @@ -522,6 +562,7 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[1, 0], ) + # MHA sdpa.logsum_exp T22 = fd.define_tensor( shape=[1, -1, -1], contiguity=[None, True, True], @@ -529,8 +570,11 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: is_cpu=False, stride_order=[2, 1, 0], ) + # MHA sdpa.philox_seed T23 = fd.define_tensor(shape=[], contiguity=[], dtype=DataType.Int, is_cpu=False) + # MHA sdpa.philox_offset T24 = fd.define_tensor(shape=[], contiguity=[], dtype=DataType.Int, is_cpu=False) + T25 = fd.ops.cast(T0, dtype=DataType.Float) S26 = fd.define_scalar(1, dtype=DataType.Int) S27 = fd.define_scalar(2048, dtype=DataType.Int) @@ -979,19 +1023,19 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None: T454 = fd.ops.add(T384, T453) T455 = fd.ops.add(T304, T454) T456 = fd.ops.cast(T455, dtype=DataType.BFloat16) - fd.add_output(T184) - fd.add_output(T186) - fd.add_output(T223) - fd.add_output(T225) - fd.add_output(T228) - fd.add_output(T232) - fd.add_output(T324) - fd.add_output(T326) - fd.add_output(T373) - fd.add_output(T376) - fd.add_output(T379) - fd.add_output(T383) - fd.add_output(T456) + fd.add_output(T184) # MLP linear1.weight_grad + fd.add_output(T186) # MLP linear1.bias_grad + fd.add_output(T223) # MLP linear0.weight_grad + fd.add_output(T225) # MLP linear0.bias_grad + fd.add_output(T228) # layer_norm1.bias_grad + fd.add_output(T232) # layer_norm1.weight_grad + fd.add_output(T324) # MHA linear1.weight_grad + fd.add_output(T326) # MHA linear1.bias_grad + fd.add_output(T373) # MHA linear0.weight_grad + fd.add_output(T376) # MHA linear0.bias_grad + fd.add_output(T379) # layer_norm0.bias_grad + fd.add_output(T383) # layer_norm0.weight_grad + fd.add_output(T456) # dx output grad def test_transformer_backward( From 13fba3b62c7b2230405c4b240b4bdaca6a363a1c Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 28 Aug 2024 15:33:32 -0400 Subject: [PATCH 27/54] Remove isProducerOf/isConsumerOf (#2867) Remove isProducerOf/isConsumerOf, replace the one instance. Original definitions were backwards. --- csrc/ir/base_nodes.cpp | 17 ----------------- csrc/ir/base_nodes.h | 6 ------ csrc/logical_domain_map.cpp | 14 ++++++++++---- 3 files changed, 10 insertions(+), 27 deletions(-) diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 16c2b61b397..5c22294b698 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -248,23 +248,6 @@ std::optional Val::getDataType() const { return dtype_; } -bool Val::isProducerOf(const Val* other) const { - NVF_ERROR(other != nullptr); - NVF_ERROR(container() == other->container()); - - if (definition() == nullptr) { - return false; - } - return std::any_of( - definition()->inputs().begin(), - definition()->inputs().end(), - [other](const Val* input) { return input == other; }); -} - -bool Val::isConsumerOf(const Val* other) const { - return other->isProducerOf(this); -} - // We don't register with the active fusion in Expr as this needs to be done // after inputs and outputs are registered with the Expr Expr::Expr(IrBuilderPasskey passkey) : Statement(passkey) {} diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index 024fe3cbf4b..efcb2e475b6 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -373,12 +373,6 @@ class NVF_API Val : public Statement { return is_fusion_output_; } - //! Returns true when other is a producer of this - bool isProducerOf(const Val* other) const; - - //! Returns true when other is a consumer of this - bool isConsumerOf(const Val* other) const; - bool sameType(const Statement* other) override { return Statement::sameType(other) && getDataType() == other->as()->getDataType(); diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 6e67ff84b63..b337b78a3e1 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -57,13 +57,19 @@ PairwiseLogicalDomainMap::PairwiseLogicalDomainMap( NVF_ERROR(producer != nullptr); NVF_ERROR(consumer != nullptr); NVF_ERROR(producer->fusion() == consumer->fusion()); + NVF_ERROR(consumer->definition() != nullptr); + auto producer_tvs_of_consumer = ir_utils::producerTvsOf(consumer); // Make sure they are really a producer and its consumer NVF_ERROR( - producer->isConsumerOf(consumer), - "Not a producer-consumer pair: ", + std::find( + producer_tvs_of_consumer.begin(), + producer_tvs_of_consumer.end(), + producer) != producer_tvs_of_consumer.end(), + "Expected ", producer, - ", ", - consumer); + " is a producer of ", + consumer, + " but it is not."); } namespace { From 148e3dca3ff9a50b6809e2f0bde68c73d235f831 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 28 Aug 2024 15:55:13 -0700 Subject: [PATCH 28/54] Clean a test. (#2861) The main change is to use testValidate to compute reference outputs. --- .../test_combined_inner_outer_reduction.cpp | 74 +++++-------------- 1 file changed, 19 insertions(+), 55 deletions(-) diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index a7bd303d7c3..55873d66bb1 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -388,9 +388,9 @@ TEST_F(NVFuserTest, CombinedSchedulerSharedProducer_CUDA) { fusion.addOutput(use_outer); } break; case 1: { - // tensor bias is a producer of the inner reduction and also a produer - // of a consumer of the outer reduction results this a not allowed, - // expect segmented + // tensor bias is a producer of the inner reduction and also a + // produer of a consumer of the outer reduction results this a not + // allowed, expect segmented auto bias_broad = add(bias, mean); auto use_inner = sum(bias_broad, {-1}); auto use_outer = add(layer_norm_results.grad_weight, bias); @@ -398,10 +398,10 @@ TEST_F(NVFuserTest, CombinedSchedulerSharedProducer_CUDA) { fusion.addOutput(use_outer); } break; case 2: { - // tensor bias is a producer of the outer reduction and also a produer - // of a consumer of the inner reduction results this a allowed, becase - // the first part of outer reduction is computed with inner reduction. - // expect unsegmented + // tensor bias is a producer of the outer reduction and also a + // produer of a consumer of the inner reduction results this a + // allowed, becase the first part of outer reduction is computed + // with inner reduction. expect unsegmented auto bias_broad = add(bias, mean); auto use_inner = add(layer_norm_results.grad_input, bias_broad); auto use_outer = sum(bias_broad, {0}); @@ -435,12 +435,10 @@ TEST_F(NVFuserTest, CombinedSchedulerSharedProducer_CUDA) { at::Tensor aten_input = at::randn(input_shape, maybe_fp16_options); at::Tensor aten_weight = at::randn(norm_shape, maybe_fp16_options); at::Tensor aten_bias = at::randn(norm_shape, maybe_fp16_options); - auto at_weight = c10::optional(aten_weight); - auto at_bias = c10::optional(aten_bias); - const float kEps = 1e-5; - auto aten_results = - at::native_layer_norm(aten_input, norm_shape, at_weight, at_bias, kEps); + constexpr float kEps = 1e-5; + auto aten_results = at::native_layer_norm( + aten_input, norm_shape, aten_weight, aten_bias, kEps); auto aten_output = std::get<0>(aten_results); auto aten_mean = std::get<1>(aten_results); auto aten_rstd = std::get<2>(aten_results); @@ -455,48 +453,19 @@ TEST_F(NVFuserTest, CombinedSchedulerSharedProducer_CUDA) { aten_bias}; auto cg_outputs = fec.runFusionWithInputs(aten_inputs); - auto aten_gradients = at::native_layer_norm_backward( - aten_grad_out, - aten_input, - norm_shape, - aten_mean, - aten_rstd, - c10::optional(aten_weight), - c10::optional(aten_bias), - {true, true, true}); - - // check the results depending on the case - at::Tensor aten_use_inner, aten_use_outer; - bool expected_segmented; + FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); switch (case_id) { - case 0: { - aten_use_inner = std::get<0>(aten_gradients) + aten_input; - aten_use_outer = std::get<1>(aten_gradients) + aten_input; - expected_segmented = true; - } break; - case 1: { - aten_use_inner = (aten_bias + aten_mean).sum({-1}); - aten_use_outer = std::get<1>(aten_gradients) + aten_bias; - expected_segmented = true; - } break; - case 2: { - aten_use_inner = std::get<0>(aten_gradients) + (aten_bias + aten_mean); - aten_use_outer = (aten_bias + aten_mean).sum({0}); - expected_segmented = false; - } break; - case 3: { - aten_use_inner = std::get<1>(aten_gradients) + (aten_bias + 1.0); - aten_use_outer = std::get<2>(aten_gradients) + (aten_bias + 1.0); - expected_segmented = true; - } break; + case 0: + case 1: + case 3: + EXPECT_TRUE(runtime->isSegmented()); + break; + case 2: + EXPECT_FALSE(runtime->isSegmented()); + break; default: NVF_ERROR(false, "Invalid case id"); } - bool is_segmented = fec.getMostRecentKernelRuntime()->isSegmented(); - NVF_CHECK( - is_segmented == expected_segmented, - expected_segmented ? "Fusion should be segmented!" - : "Fusion should not be segmented!"); auto tolerance_overwrite = ValidationConstants(); // bump tolerance, CI errors are higher than local @@ -510,11 +479,6 @@ TEST_F(NVFuserTest, CombinedSchedulerSharedProducer_CUDA) { &fusion, cg_outputs, aten_inputs, - {aten_use_inner, - aten_use_outer, - std::get<0>(aten_gradients), - std::get<1>(aten_gradients), - std::get<2>(aten_gradients)}, __LINE__, __FILE__, "", From 58dfdc1b0d4833b56fb11762d7121bdd43c1e766 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 28 Aug 2024 16:39:56 -0700 Subject: [PATCH 29/54] Split an NVF_CHECK into two. (#2868) --- csrc/device_lower/validation.cpp | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index fd55d0705cd..5091e655556 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -401,18 +401,26 @@ class VectorizeValidator : public OptInDispatch { if (!is_ldmatrix_trans) { // ldmatrix.trans is a hardware transpose instruction that can do // "vectorized" read from discontiguous memory - auto contiguity = tv->domain()->contiguity().at(last_alloc_dim_pos); NVF_CHECK( - last_alloc_dim == validator.vectorized_id_ && - contiguity.value_or(false), + last_alloc_dim == validator.vectorized_id_, "Vectorized dim for ", name, - " has to be from a contiguous inner most position. tv: ", + " has to be from an inner most position. tv: ", tv, ", allocation domain: ", - ir_utils::toString(tv->getMaybeAllocationDomain()), + tv->getMaybeAllocationDomain(), ", vectorized id: ", - validator.vectorized_id_->toString(), + validator.vectorized_id_, + ", innermost id: ", + last_alloc_dim); + + auto contiguity = tv->domain()->contiguity().at(last_alloc_dim_pos); + NVF_CHECK( + contiguity.value_or(false), + "The innermost position has to be contiguous. tv: ", + tv, + ", allocation domain: ", + tv->getMaybeAllocationDomain(), ", innermost id: ", last_alloc_dim->toString(), ", contiguity: ", From 4bc88d65d39bd8a3c455fd40f4104108477fc718 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 30 Aug 2024 14:53:14 -0700 Subject: [PATCH 30/54] Remove dead code in sync_information.cpp. (#2874) --- csrc/device_lower/analysis/sync_information.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index e1a3f5324b1..bea1c7ac150 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -468,7 +468,6 @@ SyncMap::SyncMap(Fusion* fusion) { // Stash information about parallelized producer iteration domains std::vector producer_parallel_ids( ParallelTypeBitmap::kNumParallelTypes, nullptr); - ParallelTypeBitmap producer_parallel_bitmap; // Get the parallel types that producer will be predicated off in producer // writes. @@ -507,7 +506,6 @@ SyncMap::SyncMap(Fusion* fusion) { continue; } - producer_parallel_bitmap.set(producer_ptype); producer_parallel_ids[getParallelTypeBitMapOffset(producer_ptype)] = producer_axis; } @@ -517,7 +515,6 @@ SyncMap::SyncMap(Fusion* fusion) { // Stash information about parallelized consumer iteration domains std::vector consumer_parallel_ids( ParallelTypeBitmap::kNumParallelTypes, nullptr); - ParallelTypeBitmap consumer_parallel_bitmap; for (const auto consumer_i : c10::irange(consumer->nDims())) { auto consumer_axis = consumer->axis(consumer_i); auto consumer_ptype = @@ -538,7 +535,6 @@ SyncMap::SyncMap(Fusion* fusion) { continue; } - consumer_parallel_bitmap.set(consumer_ptype); consumer_parallel_ids[getParallelTypeBitMapOffset(consumer_ptype)] = consumer_axis; } From e0c7da2a18b75b69c12fb45c1884322741a6aeba Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Fri, 30 Aug 2024 16:22:02 -0700 Subject: [PATCH 31/54] Prioritize using zero index (#2876) Fixes https://github.com/NVIDIA/Fuser/issues/2763 --- csrc/id_model/indexing.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index 2308e576ac9..94221cc3a76 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -789,15 +789,15 @@ void TensorIndexer::buildLoopIndexMap() { Val* loop_index = nullptr; ParallelType ptype = getParallelType(loop_group); - if (isParallelTypeThread(ptype)) { - loop_index = NamedScalar::getParallelIndex(ptype); - } else if ( + if ( // TODO: Cleanup needed. ir_utils::isMemoryPartitionedAcross // should be used, but that means we would need to consider // multiple outputs with different memory types, though it // should be uncommon in practice. shouldUseZeroIndex(loop_group) || isParallelTypeDeviceDim(ptype)) { loop_index = fusion->zeroVal(); + } else if (isParallelTypeThread(ptype)) { + loop_index = NamedScalar::getParallelIndex(ptype); } else { // Until the transition to the IdModel-based indexing is // completed, use the index Vals assigned for ComputeAtMap From bc0edaf4145618327c22d75e312d60013dcf568b Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 30 Aug 2024 16:46:20 -0700 Subject: [PATCH 32/54] Naming changes. (#2862) --- tests/cpp/test_combined_inner_outer_reduction.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index 55873d66bb1..5e4fde6bb18 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -163,7 +163,7 @@ INSTANTIATE_TEST_SUITE_P( // fusion should be segmented since the current combined scheduler assumes there // is no shared consumer between inter reductions and outer reductions and among // tensors in outer reductions. -TEST_F(NVFuserTest, CombinedSchedulerSharedConsumer_CUDA) { +TEST_F(CombinedSchedulerTest, SharedConsumer) { auto runTest = [](const std::vector& batch_shape, const std::vector& norm_shape, DataType dtype, @@ -312,7 +312,7 @@ TEST_F(NVFuserTest, CombinedSchedulerSharedConsumer_CUDA) { // This case is to test the correctness of the combined inner and outer // scheduler. One tensor is using the inner reduction results and outer // reduction results. should be segmented. -TEST_F(NVFuserTest, CombinedSchedulerSharedProducer_CUDA) { +TEST_F(CombinedSchedulerTest, SharedProducer) { auto runTest = [](const std::vector& batch_shape, const std::vector& norm_shape, DataType dtype, @@ -497,7 +497,7 @@ TEST_F(NVFuserTest, CombinedSchedulerSharedProducer_CUDA) { } // Manual schedule of inner and outer reduction on the same tensor -TEST_F(NVFuserTest, CombinedReduction_CUDA) { +TEST_F(CombinedSchedulerTest, CombinedReduction) { // https://github.com/csarofeen/pytorch/issues/2566 // this case will fail, if using tidx = 8 and tidy = 64 // for inner reduction, tidy is derived as 10240 / (tidx*vecx*nloadx) = 64 @@ -658,7 +658,7 @@ TEST_F(NVFuserTest, CombinedReduction_CUDA) { // Manual schedule of inner and outer reduction on the same tensor. Each block // will do multiple reductions. -TEST_F(NVFuserTest, CombinedReductionMultiPerBlock_CUDA) { +TEST_F(CombinedSchedulerTest, CombinedReductionMultiPerBlock) { auto ceilDiv = [](const int a, const int b) { return (a + b - 1) / b; }; constexpr bool verbose = false; const auto dev_prop = at::cuda::getCurrentDeviceProperties(); @@ -836,7 +836,7 @@ TEST_F(NVFuserTest, CombinedReductionMultiPerBlock_CUDA) { // Reproduce of issue 1023, where iteration axis in inner reduction tv doesn't // match to reduction axis in outer reduction tv. -TEST_F(NVFuserTest, CombinedSchedulerInnerOuterMismatch) { +TEST_F(CombinedSchedulerTest, InnerOuterMismatch) { auto test = [](const std::vector& outer_reduction_axis) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); @@ -891,7 +891,7 @@ TEST_F(NVFuserTest, CombinedSchedulerInnerOuterMismatch) { // outer broadcast tvs, e.g. in layer norm backward and RMS norm backward. // This test covers the branch where the outer broadcast tensor is not exist // and data type is fp32, so the buffer is not projected to inputs. -TEST_F(NVFuserTest, CombinedSchedulerInnerOuterNoOuterBroadcastTv) { +TEST_F(CombinedSchedulerTest, InnerOuterNoOuterBroadcastTv) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); From 1de3cd465add106f96e59af7096ab5a673747d37 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 30 Aug 2024 17:53:53 -0700 Subject: [PATCH 33/54] Fix the coding style. (#2879) --- csrc/ops/arith.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 399b128672f..4026d7b7678 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -1176,16 +1176,16 @@ TensorView* newForReduction( auto reduced_axis_iter = axes_set.begin(); for (const auto dim : c10::irange(orig_domain.size())) { - bool isReduction = false; + bool is_reduction = false; if (reduced_axis_iter != axes_set.end() && *reduced_axis_iter == dim) { - isReduction = true; + is_reduction = true; reduced_axis_iter++; } const IterDomain* id = orig_domain[dim]; IterDomain* new_id = nullptr; - if (isReduction) { + if (is_reduction) { if (id->isBroadcast()) { NVF_CHECK( id->isImplicitBroadcast(), @@ -2345,9 +2345,9 @@ static TensorView* newForMma( auto axis_iter = axes_set.begin(); for (const auto dim : c10::irange(orig_domain_a.size())) { - bool isReduction = false; + bool is_reduction = false; if (axis_iter != axes_set.end() && *axis_iter == dim) { - isReduction = true; + is_reduction = true; axis_iter++; } @@ -2356,7 +2356,7 @@ static TensorView* newForMma( : orig_domain_a[dim]; NVF_CHECK( - !(isReduction && id->isBroadcast() && !id->isImplicitBroadcast()), + !(is_reduction && id->isBroadcast() && !id->isImplicitBroadcast()), "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ", id, " of tensor ", @@ -2367,7 +2367,7 @@ static TensorView* newForMma( new_domain.push_back( IterDomainBuilder(id->start(), id->extent()) .stop_offset(id->stopOffset()) - .iter_type(isReduction ? IterType::Reduction : id->getIterType()) + .iter_type(is_reduction ? IterType::Reduction : id->getIterType()) .build()); } From 5bc8d3e554a35a3c37a0f7e9f601af7bef9d4032 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sun, 1 Sep 2024 15:22:51 -0700 Subject: [PATCH 34/54] Fix the handling of `b(DID)`. (#2877) --- csrc/evaluator_common.cpp | 27 ++++--- csrc/expr_evaluator.cpp | 100 +++++++++++++----------- csrc/ops/utils.cpp | 14 ++-- tests/cpp/test_multidevice_sharding.cpp | 58 ++++++++++++++ 4 files changed, 134 insertions(+), 65 deletions(-) diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 1042931914e..c7a45c79359 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -348,19 +348,22 @@ void PrecomputedValues::bindTensorMetaData( for (const auto dim : c10::irange(logical_domain.size())) { IterDomain* id = logical_domain[dim]; - auto dim_size = tensor.size(static_cast(dim)); - if (id->isDeviceDim()) { - dim_size = tv->getDeviceMesh().size(id->getParallelType()); - } - - if (id->hasExpandedExtent()) { - Val* extent = id->extent(); - Val* expanded_extent = id->expandedExtent(); - bindValue(extent->evaluatorIndex(), 1L); - bindValue(expanded_extent->evaluatorIndex(), dim_size); + const auto dim_size = tensor.size(static_cast(dim)); + if (id->isBroadcast()) { + // DIDs are ignored for broadcast. See MultideviceShardingTest.Broadcast + // and .ExpandedBroadcast. + bindValue(id->extent()->evaluatorIndex(), 1L); + if (id->hasExpandedExtent()) { + bindValue(id->expandedExtent()->evaluatorIndex(), dim_size); + } } else { - Val* extent = id->extent(); - bindValue(extent->evaluatorIndex(), dim_size); + if (id->isDeviceDim()) { + bindValue( + id->extent()->evaluatorIndex(), + tv->getDeviceMesh().size(id->getParallelType())); + } else { + bindValue(id->extent()->evaluatorIndex(), dim_size); + } } } diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index c59842b2037..85476396c6a 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -174,53 +174,61 @@ void ExpressionEvaluator::bind_( t.dim()); for (auto i : c10::irange(t.dim())) { auto id = logical_domain[i]; - if (id->hasExpandedExtent()) { - // Verify that t is also expanded - NVF_ERROR( - t.size(i) == 1 || t.stride(i) == 0, - "IterDomain ", - id->toString(), - " in ", - getInputPosString(tv), - "TensorView ", - tv->toString(), - " has expanded extent but input tensor has size ", - t.size(i), - " and stride ", - t.stride(i), - " in dimension ", - i); - bind_( - logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate); - } else if (logical_domain[i]->isDeviceDim()) { - // Currently we have the restrictions: - // (1) Devices parallelized axis extent == DeviceMesh's extent - // (2) Device parallelized axis cannot be split or merged - // Therefore, the device parallelized extents will always be allocated - // with size 1, but the symbolic axis extent is binded with the extent - // of the DeviceMesh - NVF_CHECK( - 1 == t.size(i), - "TensorView ", - tv->toString(), - getInputPosString(tv), - " IterDomain ", - id->toString(), - "is sharded and must have size 1, but input tensor has size ", - t.size(i)); - NVF_CHECK( - tv->hasDeviceMesh(), - "TV ", - tv->toString(), - getInputPosString(tv), - " has an empty DeviceMesh with DID parallelization") - bind_( - logical_domain[i]->extent(), - static_cast( - tv->getDeviceMesh().size(logical_domain[i]->getParallelType())), - evaluate_validate); + if (id->isBroadcast()) { + // DIDs are ignored for broadcast. + bind_(logical_domain[i]->extent(), 1, evaluate_validate); + if (id->hasExpandedExtent()) { + // Verify that t is also expanded + NVF_ERROR( + t.size(i) == 1 || t.stride(i) == 0, + "IterDomain ", + id->toString(), + " in ", + getInputPosString(tv), + "TensorView ", + tv->toString(), + " has expanded extent but input tensor has size ", + t.size(i), + " and stride ", + t.stride(i), + " in dimension ", + i); + bind_( + logical_domain[i]->expandedExtent(), + t.size(i), + evaluate_validate); + } } else { - bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate); + if (logical_domain[i]->isDeviceDim()) { + // Currently we have the restrictions: + // (1) Devices parallelized axis extent == DeviceMesh's extent + // (2) Device parallelized axis cannot be split or merged + // Therefore, the device parallelized extents will always be allocated + // with size 1, but the symbolic axis extent is binded with the extent + // of the DeviceMesh + NVF_CHECK( + 1 == t.size(i), + "TensorView ", + tv->toString(), + getInputPosString(tv), + " IterDomain ", + id->toString(), + "is sharded and must have size 1, but input tensor has size ", + t.size(i)); + NVF_CHECK( + tv->hasDeviceMesh(), + "TV ", + tv->toString(), + getInputPosString(tv), + " has an empty DeviceMesh with DID parallelization") + bind_( + logical_domain[i]->extent(), + static_cast(tv->getDeviceMesh().size( + logical_domain[i]->getParallelType())), + evaluate_validate); + } else { + bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate); + } } } } diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 8b911d623c0..3b6bf6d561f 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -323,6 +323,13 @@ IterDomain* newOutputIterDomain( continue; } + NVF_ERROR( + id->getParallelType() == ParallelType::Serial || + isParallelTypeDeviceDim(id->getParallelType()), + id->getParallelType(), + " is not expected when building ops."); + parallel_type = promoteParallelType(parallel_type, id->getParallelType()); + if (id->isBroadcast()) { if (id->hasExpandedExtent()) { expanded_extent_val = @@ -331,13 +338,6 @@ IterDomain* newOutputIterDomain( continue; } - NVF_ERROR( - id->getParallelType() == ParallelType::Serial || - isParallelTypeDeviceDim(id->getParallelType()), - id->getParallelType(), - " is not expected when building ops."); - parallel_type = promoteParallelType(parallel_type, id->getParallelType()); - if (extent_is_from_symbolic && !id->isSymbolic()) { // We prefer to use extents from non-Symbolic inputs if there are any // because they might indicate a broadcast axis that is resolved in this diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 18246a77239..fbccf278911 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -230,4 +230,62 @@ TEST_F(MultideviceShardingTest, Issue2758) { __FILE__); } +// This test and the following `ExpandedBroadcast` test verify the expression +// evaluator correctly binds the extent of a broadcast dimension to 1 and the +// expanded extent to the tensor size. There used to be a bug where it +// incorrectly binds the extent(s) to the mesh size. +// +// `b(DID{i0})` and `b(i0)` bear the same semantics. The former is used more +// often due to how parallelizeAllLike is implemented. +TEST_F(MultideviceShardingTest, Broadcast) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = TensorViewBuilder() + .dtype(DataType::Float) + .contiguity({std::nullopt, true}) + .shape({1, -1}) + .build(); + in->setDeviceMesh(mesh); + in->axis(0)->parallelize(ParallelType::DIDx); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + FusionExecutorCache fec(std::move(fusion)); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({1, 8}, options); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); +} + +TEST_F(MultideviceShardingTest, ExpandedBroadcast) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = TensorViewBuilder() + .dtype(DataType::Float) + .contiguity({std::nullopt, true}) + .shape({3, -1}) + .expanded({true, false}) + .build(); + in->setDeviceMesh(mesh); + in->axis(0)->parallelize(ParallelType::DIDx); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + FusionExecutorCache fec(std::move(fusion)); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor in_tensor = at::randn({8}, options).as_strided({3, 8}, {0, 1}); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); +} + } // namespace nvfuser From 02f6514ae25637e89efb4d2317299c52e9d574dc Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sun, 1 Sep 2024 19:02:20 -0700 Subject: [PATCH 35/54] Fix the StagedReduction test when `-np 1`. (#2878) --- csrc/codegen.cpp | 14 ++-- .../analysis/thread_predicate.cpp | 8 +- tests/cpp/test_multidevice_pipeline.cpp | 80 ++++++++++--------- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 5a23c47dc34..82b6c91a99e 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -579,11 +579,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } void handle(const kir::TensorIndex* ti) final { - bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global && - kernel_->summary().sync_map->needsRawSync(ti->view()).hasBID(); - bool is_pointer = isPointerType(ti->index()->dtype()); - if (is_pointer) { - bool is_u32_ptr = ti->index()->dtype() == DataType::SMemAddress; + if (isPointerType(ti->index()->dtype())) { + const bool is_u32_ptr = ti->index()->dtype() == DataType::SMemAddress; if (is_u32_ptr) { // DataType::SMemAddress is implemented as uint32_t in C++. The problem // for this implementation is, the type promotion rule in C++ for @@ -598,10 +595,13 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } return; } - bool different_dtype = ti->view()->dtype() != ti->dtype(); - if (is_volatile) { + + if (ti->view()->getMemoryType() == MemoryType::Global && + kernel_->summary().sync_map->needsRawSync(ti->view()).hasBID()) { code_ << "*(volatile " << ti->getDataType().value() << "*)&"; } + + const bool different_dtype = ti->view()->dtype() != ti->dtype(); if (different_dtype) { code_ << "(*reinterpret_cast<" << ti->getDataType().value() << "*>(&"; } diff --git a/csrc/device_lower/analysis/thread_predicate.cpp b/csrc/device_lower/analysis/thread_predicate.cpp index 1ade9c0cdc1..408140590b1 100644 --- a/csrc/device_lower/analysis/thread_predicate.cpp +++ b/csrc/device_lower/analysis/thread_predicate.cpp @@ -842,11 +842,15 @@ ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains( const bool output_smem = tv->getMemoryType() == MemoryType::Shared; for (auto id : iter_domains) { - if (!id->isBroadcast() || - !GpuLower::current()->concretizedBroadcastDomains()->isConcretized( + if (!id->isBroadcast()) { + continue; + } + + if (!GpuLower::current()->concretizedBroadcastDomains()->isConcretized( id)) { continue; } + if (id->isBlockDim() || (!output_smem && id->isThreadDim())) { parallel_broadcast.set(id->getParallelType()); } diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index 8a161029df7..cf2e7952d79 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -7,6 +7,11 @@ // clang-format on #include +#include +#include + +#include + #include #include #include @@ -15,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -29,18 +35,11 @@ #include #include #include -#include #include #include -#include -#include - namespace nvfuser { -using namespace torch::jit::fuser::cuda; -using namespace at::indexing; - // To run the following tests on several devices, pytorch must be installed with // the flag USE_DISTRIBUTED=1 and nccl support. With that, nvFuser is built by // default with NVFUSER_DISTRIBUTED defined. Then, on a node with at least 6 @@ -375,31 +374,33 @@ class PipelineTestStagedReduction public ::testing::WithParamInterface {}; // 1D staged reduction -// Inputs: X[A,B,C] +// Inputs: X[num_devices,B,C] TEST_P(PipelineTestStagedReduction, StagedReduction) { auto scheduling_mode = GetParam(); - int num_devices = communicator_->size(); - int A = num_devices; - int B = 8; - int C = 64; - std::vector unsharded_input_sizes = {A, B, C}; - std::vector input_sizes(unsharded_input_sizes); - input_sizes[0] = 1; + const int num_devices = communicator_->size(); + constexpr int B = 8; + constexpr int C = 64; FusionGuard fg(fusion.get()); - TensorView* tv0 = makeConcreteTensor(unsharded_input_sizes); + // The first dimension is made symbolic so `tv_out->definition()` won't + // become a squeeze when num_devices == 1. This wouldn't be a problem for + // automatic mode. However, for the manual mode, the scheduling code below + // assumes `tv_out->definition()` can be lowered to communication. A squeeze + // can't. + TensorView* tv0 = TensorViewBuilder() + .dtype(DataType::Float) + .contiguity(true) + .shape({-1, B, C}) + .build(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + tv0->setDeviceMesh(mesh); TensorView* tv1 = sum(tv0, {2}); TensorView* tv_out = sum(tv1, {0}); fusion->addInput(tv0); fusion->addOutput(tv_out); - // multi device scheduling: - auto mesh = DeviceMesh::createForNumDevices(num_devices); - for (auto tv : {tv0, tv1, tv_out}) { - tv->setDeviceMesh(mesh); - } - for (auto tv : {tv0, tv1}) { + for (auto* tv : {tv0, tv1}) { tv->axis(0)->parallelize(ParallelType::DIDx); } @@ -430,22 +431,22 @@ TEST_P(PipelineTestStagedReduction, StagedReduction) { // tv1[I0{A}, I1{B}, R2i{32}] = tv3[I0{A}, I1{B}, R2oi{4}, I2i{32}] // clang-format on - // Incrementally, can print in between for debugging - tv0->computeAt(tv2, 2); - tv2->computeAt(tv3, 2); - tv3->computeAt(tv1, 2); - - // Re do it all at once, because why not. - tv0->computeAt(tv1, 2); + // tv1 is a segment boundary so must be in global. This wouldn't be + // needed if the fusion were scheduled automatically. + tv1->setMemoryType(MemoryType::Global); + // Use `tv2` as the reference tensor because it contains the most + // parallel IterDomains. + tv2->axis(1)->parallelize(ParallelType::BIDx); tv2->axis(3)->parallelize(ParallelType::Unroll); - tv1->axis(1)->parallelize(ParallelType::BIDx); - tv1->setMemoryType( - MemoryType::Global); // necessary to avoid runtime error - - tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike( + tv2, + /*pos=*/-1, + // Don't propagate the parallelization to `tv_out` because that's in + // a different, resharding segment. + /*selected_tv=*/{tv0, tv1, tv2, tv3}); + inlineMost(); break; } case SchedulingMode::Automatic: @@ -453,9 +454,12 @@ TEST_P(PipelineTestStagedReduction, StagedReduction) { break; } - unsharded_inputs = {at::randn(unsharded_input_sizes, tensor_options)}; - ref_unsharded_outputs = {at::sum( - unsharded_inputs.at(0).toTensor(), at::OptionalIntArrayRef({0, 2}))}; + at::Tensor unsharded_input_tensor = + at::randn({num_devices, B, C}, tensor_options); + at::Tensor ref_unsharded_output_tensor = + unsharded_input_tensor.sum(at::IntArrayRef({0, 2})); + unsharded_inputs = {unsharded_input_tensor}; + ref_unsharded_outputs = {ref_unsharded_output_tensor}; executeAndValidate(/* validate_with_prescribed_values */ true); } From 744bf544f85f38c47d3f2ceff2111b91fa4a320f Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Mon, 2 Sep 2024 09:17:05 -0400 Subject: [PATCH 36/54] Cache `ir_utils::allTvs` as part of Fusion (#2873) This caches `ir_utils::allTvs(fusion)` as `fusion->allTvs()`. The cache is automatically invalidated whenever the TV graph topology changes; this mechanism is the same one used to recompute `Expr` uses automatically. --------- Co-authored-by: Christian Sarofeen --- csrc/fusion.cpp | 43 +++++++++++++++++++++++++++++++++--------- csrc/fusion.h | 11 ++++++++++- csrc/ir/base_nodes.cpp | 2 +- csrc/kernel_cache.cpp | 6 ++++-- 4 files changed, 49 insertions(+), 13 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index b2fe3bab274..bd864c9f881 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -116,6 +116,14 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_; + if (from->all_tvs_ptr_ != nullptr) { + to->all_tvs_ptr_ = std::make_unique>(); + to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size()); + for (TensorView* from_tv : *from->all_tvs_ptr_) { + to->all_tvs_ptr_->push_back(ir_cloner.clone(from_tv)->as()); + } + } + return ir_cloner; } @@ -168,7 +176,8 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); - all_tv_uses_valid_ = false; + invalidateTvsAndUses(); + is_during_update_uses_ = false; } @@ -179,13 +188,19 @@ void Fusion::removeExpr(Expr* expr) { // we're going with the strictest model which errors. for (auto out : expr->outputs()) { + if (out->isA()) { + invalidateTvsAndUses(); + } out->setDefinition(nullptr); } // Remove uses in inputs for (auto inp : expr->inputs()) { - // Note that if inp is a TensorView, this may call invalidateTvUses + // Note that if inp is a TensorView, this may call invalidateTvsAndUses inp->removeUse(expr); + if (inp->isA()) { + invalidateTvsAndUses(); + } } IrContainer::removeExpr(expr); @@ -228,6 +243,8 @@ void Fusion::removeVal(Val* val) { removeExpr(e); } IrContainer::removeVal(val); + + invalidateTvsAndUses(); } void Fusion::addInput(Val* input) { @@ -250,7 +267,7 @@ void Fusion::addInput(Val* input) { inputs_.push_back(input); input->setIsFusionInput(true); - all_tv_uses_valid_ = false; + invalidateTvsAndUses(); } void Fusion::addOutputInternal(Val* output) { @@ -264,7 +281,7 @@ void Fusion::addOutputInternal(Val* output) { outputs_.push_back(output); output->setIsFusionOutput(true); - all_tv_uses_valid_ = false; + invalidateTvsAndUses(); } void Fusion::addOutput(Val* output) { @@ -290,7 +307,7 @@ void Fusion::removeInput(Val* input) { inputs_.erase(find_input); } input->setIsFusionInput(false); - all_tv_uses_valid_ = false; + invalidateTvsAndUses(); } void Fusion::removeOutput(Val* output) { @@ -299,7 +316,7 @@ void Fusion::removeOutput(Val* output) { outputs_.erase(find_output); } output->setIsFusionOutput(false); - all_tv_uses_valid_ = false; + invalidateTvsAndUses(); } void Fusion::replaceOutput(Val* output, Val* replacement) { @@ -326,7 +343,7 @@ void Fusion::replaceOutput(Val* output, Val* replacement) { } } // Mark uses invalid so that they will be reset next time uses() is called - invalidateTvUses(); + invalidateTvsAndUses(); } // Temporary WAR for issue #1112 @@ -582,7 +599,7 @@ void Fusion::registerExpr(Expr* expr) { // Don't just add this expr as a use of the input if it's a tensor as the // whole fusion needs to be traversed to rebuild the usage lists if (input->isA()) { - invalidateTvUses(); + invalidateTvsAndUses(); } else { input->addUse(expr); } @@ -605,7 +622,7 @@ void Fusion::registerExpr(Expr* expr) { // If that happens, our definition-based traversal can change and // introduce whole new branches, so we need to recompute the uses_ // vector after setDefinition. - invalidateTvUses(); + invalidateTvsAndUses(); } } } @@ -854,4 +871,12 @@ bool isExpressionEvaluated(Fusion* fusion) { }); } +std::vector Fusion::allTvs() { + if (all_tvs_ptr_ == nullptr) { + all_tvs_ptr_ = + std::make_unique>(ir_utils::allTvs(this)); + } + return std::vector(*all_tvs_ptr_); +} + } // namespace nvfuser diff --git a/csrc/fusion.h b/csrc/fusion.h index 72280c384c7..b640fc02fb5 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -429,6 +429,12 @@ class NVF_API Fusion : public IrContainer { expected_dynamic_smem_bytes_ = bytes; } + //! This is a cached version of ir_utils::allTvs that is invalidated. Return a + //! copy of the vector instead of a reference as it can be invalidated by many + //! operations. If we returned a reference and are iterating on it while + //! making modifications to the fusion, it can easily cause a segfault. + std::vector allTvs(); + protected: friend SegmentCandidateFinder; friend SegmentedFusion; @@ -456,8 +462,9 @@ class NVF_API Fusion : public IrContainer { //! Declare that TensorView uses need to be updated (but don't actually do //! the update). - void invalidateTvUses() { + void invalidateTvsAndUses() { all_tv_uses_valid_ = false; + all_tvs_ptr_.reset(); } private: @@ -485,6 +492,8 @@ class NVF_API Fusion : public IrContainer { // If set to a non-negative value during scheduling, this will be checked by // the executor. int64_t expected_dynamic_smem_bytes_ = -1LL; + + std::unique_ptr> all_tvs_ptr_ = nullptr; }; // Returns true if all fusion outputs are expression evaluated. diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 5c22294b698..43c2b28b01c 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -110,7 +110,7 @@ bool Val::removeUse(Expr* expr) { uses_.erase(it); if (this->isA()) { // Call for a rebuild of uses_ vector - fusion()->invalidateTvUses(); + fusion()->invalidateTvsAndUses(); } return true; } diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index af2ecdc1d44..8e8c6be4f60 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -988,7 +988,8 @@ FusionKernelRuntime::FusionKernelRuntime( // SchedulerRuntimeInfo modifies the fusion, so it is required for both // compile paths. - std::vector all_tvs = ir_utils::allTvs(fusion.get()); + std::vector all_tvs = + fusion->allTvs(); // ir_utils::allTvs(fusion.get()); SchedulerRuntimeInfo runtime_info( fusion.get(), args, nullptr, all_tvs, forced_index_type); @@ -1453,7 +1454,8 @@ std::optional FusionKernelRuntime:: // Get all tensorviews for segmented fusion std::vector all_tvs_for_fusion_to_run = - ir_utils::allTvs(fusion_to_run); + fusion_to_run->allTvs(); + // ir_utils::allTvs(fusion_to_run); SchedulerRuntimeInfo fusion_to_run_info( fusion_to_run, From e46465340100b136d51455d86c86aadddb790947 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 2 Sep 2024 08:44:00 -0700 Subject: [PATCH 37/54] Remove unnecessary calls to `.ndims`. (#2882) --- tests/cpp/test_alias.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp index 0a8a427a6b8..b68b3678023 100644 --- a/tests/cpp/test_alias.cpp +++ b/tests/cpp/test_alias.cpp @@ -838,7 +838,6 @@ TEST_F(AliasTest, MergeTwoExpandedBroadcasts) { FusionGuard fg(fusion.get()); TensorView* in = TensorViewBuilder() - .ndims(3) .dtype(DataType::Float) .contiguity({std::nullopt, std::nullopt, std::nullopt}) .shape({4, 5, 6}) @@ -862,7 +861,6 @@ TEST_F(AliasTest, MergeBroadcastsBetweenConcretes) { FusionGuard fg(fusion.get()); TensorView* in = TensorViewBuilder() - .ndims(4) .dtype(DataType::Float) .contiguity({true, std::nullopt, std::nullopt, true}) .shape({2, 3, 5, 7}) From fabbc1aae524332a4ac9607fb952dfe93dde53a1 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Mon, 2 Sep 2024 16:31:41 -0400 Subject: [PATCH 38/54] Cache exact map to speed up ExpressionEvaluator ctor (#2872) This just precomputes an `ExactLogicalDomainMap` for the dynamic fusion in `FusionExecutorCache`, so that we don't need to re-build that map every time we call `propagateBoundValuesThroughExactMaps`. This speeds up the dynamic shape overhead due to building `DynamicTransformConcretizationInfo` considerably. In the tested case the `getKernelRuntimeFor` overhead is reduced from 834 us to 648 us with this PR, a reduction of 22%. --- csrc/dynamic_transform.cpp | 7 +++++-- csrc/dynamic_transform.h | 4 +++- csrc/expr_evaluator.cpp | 11 +++++++++-- csrc/expr_evaluator.h | 5 ++++- csrc/kernel_cache.cpp | 6 ++++-- csrc/kernel_cache.h | 4 ++++ 6 files changed, 29 insertions(+), 8 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 726555769a0..2e3a7a3d49d 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -252,7 +253,8 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { DynamicTransformConcretizationInfo::DynamicTransformConcretizationInfo( const DynamicTransformInitialInfo* initial_info, - ExpressionEvaluator* expr_eval) + ExpressionEvaluator* expr_eval, + ExactLogicalDomainMap* exact_map) : initial_info_(initial_info) { NVF_ERROR( !fusion()->isA(), @@ -260,7 +262,8 @@ DynamicTransformConcretizationInfo::DynamicTransformConcretizationInfo( // Make sure all exactly mapped IDs have the same value in the // evaluator when any one of the IDs has a known value - expr_eval->propagateBoundValuesThroughExactMaps(initial_info_->fusion()); + expr_eval->propagateBoundValuesThroughExactMaps( + initial_info_->fusion(), exact_map); analyzeReshapes(expr_eval); diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 98376007384..e5b25f0dab7 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -141,7 +142,8 @@ class DynamicTransformConcretizationInfo { public: NVF_API DynamicTransformConcretizationInfo( const DynamicTransformInitialInfo* initial_info, - ExpressionEvaluator* expr_eval); + ExpressionEvaluator* expr_eval, + ExactLogicalDomainMap* exact_map = nullptr); //! Return a vector of integers each corresponding to the position in //! initialInfo()->getMaybeZeroExtents() of an extent Val which is guaranteed diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 85476396c6a..764ea83aa50 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -352,11 +352,18 @@ void ExpressionEvaluator::print() const { debug() << "--------------------\n\n"; } -void ExpressionEvaluator::propagateBoundValuesThroughExactMaps(Fusion* fusion) { +void ExpressionEvaluator::propagateBoundValuesThroughExactMaps( + Fusion* fusion, + ExactLogicalDomainMap* exact_map) { // We map Symbolic IterDomains here only if their extents match. This avoids // mapping between symbolic domains that might concretize to an (Iteration, // Broadcast) pair from a resolved broadcast. - const auto mapped_sets = ExactLogicalDomainMap(fusion).getMappedSets(); + std::unique_ptr exact_map_ptr; + if (exact_map == nullptr) { + exact_map_ptr = std::make_unique(fusion); + exact_map = exact_map_ptr.get(); + } + const auto mapped_sets = exact_map->getMappedSets(); for (const auto& set : mapped_sets.disjointSets()) { int64_t known_size = -1; diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h index 547d4cd3ba4..ef1114bb8ff 100644 --- a/csrc/expr_evaluator.h +++ b/csrc/expr_evaluator.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -90,7 +91,9 @@ class ExpressionEvaluator { //! root IDs that are exactly mapped also get bound to the same //! value. This is currently just done with ExactLogicalDomainMap, but //! can be similarly done with the Exact CA map as well. - void propagateBoundValuesThroughExactMaps(Fusion* fusion); + void propagateBoundValuesThroughExactMaps( + Fusion* fusion, + ExactLogicalDomainMap* exact_map = nullptr); ExpressionEvaluator clone(IrCloner& ir_cloner) const; diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 8e8c6be4f60..6402ed56620 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -440,6 +441,7 @@ FusionExecutorCache::FusionExecutorCache( int64_t fusion_id, bool auto_schedule) : fusion_(std::move(fusion)), + exact_map_(std::make_unique(fusion_.get())), fusion_id_{fusion_id}, auto_schedule_(auto_schedule) {} @@ -689,7 +691,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( auto expr_eval = executor_utils::bindInputs(args, fusion_.get()); cached_conc_info_.emplace_back( std::make_unique( - &initial_info, &expr_eval)); + &initial_info, &expr_eval, exact_map_.get())); conc_info = cached_conc_info_.back().get(); } @@ -886,7 +888,7 @@ void FusionExecutorCache::deserialize( auto expr_eval = executor_utils::bindInputs(args, fusion_.get()); cached_conc_info_.emplace_back( std::make_unique( - &initial_info, &expr_eval)); + &initial_info, &expr_eval, exact_map_.get())); conc_info = cached_conc_info_.back().get(); } diff --git a/csrc/kernel_cache.h b/csrc/kernel_cache.h index 6562cc7fb4a..8f41fa9cade 100644 --- a/csrc/kernel_cache.h +++ b/csrc/kernel_cache.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -732,6 +733,9 @@ class FusionExecutorCache { //! concretization info) pair std::vector deterministic_conc_info_; + //! This is cached to speed up finding concretization info + std::unique_ptr exact_map_; + //! Logging state for most recent compilation bool profiling_ = false; From dca416dc2d5b1dcfb37202ab373f08fb55023c10 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Mon, 2 Sep 2024 17:21:37 -0400 Subject: [PATCH 39/54] Add more comments around getKernelRuntimeFor (#2871) Co-authored-by: Christian Sarofeen Co-authored-by: Jingyue Wu --- csrc/kernel_cache.cpp | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 6402ed56620..3c83ed5642a 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -661,10 +661,39 @@ DynamicTransformInitialInfo& FusionExecutorCache::initialInfo() { return initial_info_.value(); } +// getKernelRuntimeFor inspects the inputs to find a usable FusionKernelRuntime +// as quickly as possible. To do so we cache at multiple levels: +// A. If we have seen these inputs before, we re-use the FusionKernelRuntime +// we used last time. Here, we mean the same input tensor sizes, as well as +// same input scalars if they are used to compute an intermediate or output +// tensor size. +// B. We check how we should concretize the dynamic fusion using these +// inputs. If we have not concretized the fusion this way previously, then we +// concretize it and create a new FusionKernelRuntime, which means segmenting +// and compiling new kernels. Otherwise, we check whether we can re-use any of +// the previously-segmented runtimes. +// i. We look at all FusionKernelRuntimes that have been used with +// this concretized fusion. +// ii. For each of those runtimes, we compare the heuristic parameters for +// each segment to those that we compute using the current inputs. +// If we do not find any runtimes whose heuristic parameters match, then we +// create a new FusionKernelRuntime, which means segmenting and compiling all +// new kernels. +// +// In summary, we have the following paths, in order of hottest to coldest: +// 1. Input ID cache hit: re-use runtime used last time these inputs were seen +// 2. Concretization match, runtime heuristic params match: re-use runtime +// after checking concretization/heuristics. +// 3. Concretization match but no runtime heuristic params match. Segment +// to create new FusionKernelRuntime +// 4. Concretization is unseen: Segment to create a new FusionKernelRuntime +// For re-used shapes, path 1 is most relevant. For dynamic shape problems with +// a large number of unique shapes, path 2 is important. Paths 3 and 4 are slow +// since they both involve re-segmentation and re-compilation of the Fusion. FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( const KernelArgumentHolder& args, std::optional forced_index_type) { - // Check for id hit case + // Check for id hit case (Path 1) auto unique_id_opt = args.getCacheId(); NVF_CHECK( unique_id_opt.has_value(), @@ -696,7 +725,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( } // Initialize or fetch vector of FusionKernelRuntime objects associated with - // each pair of device ID and + // each pair of device ID and concretization info. auto config = std::make_pair(args.getDeviceIndex(), conc_info); auto& kernel_runtimes = kernel_runtimes_.try_emplace(config).first->second; auto result = @@ -712,13 +741,17 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( FusionKernelRuntime* kernel_runtime = nullptr; + // reusing indicates whether we are following Path 2 (true) or Paths 3/4 + // (false) bool reusing = false; // By default, we try to avoid recompiling whenever possible. However, this // can lead to suboptimal code if we only check that a compiled kernel is able // to run with some inputs, instead of whether it is optimal to do so. The // NVFUSER_DISABLE=kernel_reuse option is a coarse tool that just enforces // that whenever we encounter a new set of input shapes we segment and compile - // a new FusionKernelRuntime. + // a new FusionKernelRuntime. Effectively, this option disables Paths 2 and 3 + // above so that we only have Path 1 (hottest re-use path) and Path 4 (full + // recompile). if (!isOptionDisabled(DisableOption::KernelReuse)) { auto reuse_it = std::find_if( kernel_runtimes.begin(), @@ -740,6 +773,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( } if (!reusing) { + // Paths 3 or 4 // cache miss, need to re-build an optimized graph for this case // Clone fusion_ so that we can safely use an ExpressionEvaluator on it, for From f669fcf78b5c5dee6c08715aeb6f2e36a6af964b Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 3 Sep 2024 00:51:14 -0400 Subject: [PATCH 40/54] Cleanup all uses of ir_utils::allTvs (#2884) Follow up to https://github.com/NVIDIA/Fuser/pull/2873 removes all uses of the ir_utils variant in favor of the Fusion variant. --------- Co-authored-by: Jacob Hinkle --- csrc/compute_at_map.cpp | 6 ++-- .../device_lower/analysis/divisible_split.cpp | 2 +- .../analysis/thread_predicate.cpp | 2 +- csrc/device_lower/lower2device.cpp | 2 +- csrc/device_lower/pass/expr_sort.cpp | 2 +- csrc/device_lower/pass/loops.cpp | 4 +-- csrc/device_lower/validation.cpp | 4 +-- csrc/evaluator_common.cpp | 2 +- csrc/fusion.cpp | 27 +++++++++++++++-- csrc/fusion_segmenter.cpp | 10 +++---- csrc/id_model/validation_utils.cpp | 2 +- csrc/inlining.cpp | 4 +-- csrc/ir/graphviz.cpp | 2 +- csrc/ir/utils.cpp | 23 ++------------ csrc/ir/utils.h | 3 -- csrc/kernel_cache.cpp | 4 +-- csrc/multidevice/utils.cpp | 4 +-- csrc/preseg_passes/mark_aliases_prepare.cpp | 4 +-- csrc/preseg_passes/propagate_shardings.cpp | 2 +- csrc/python_frontend/fusion_definition.cpp | 4 +-- csrc/scheduler/normalization_inner_outer.cpp | 2 +- csrc/scheduler/pointwise.cpp | 2 +- csrc/scheduler/registry.cpp | 2 +- csrc/scheduler/registry_utils.cpp | 2 +- csrc/scheduler/utils.cpp | 18 +++++------ tests/cpp/test_gpu3.cpp | 12 ++++---- tests/cpp/test_gpu_compute_with.cpp | 4 +-- tests/cpp/test_gpu_fused_reduction.cpp | 5 ++-- tests/cpp/test_gpu_outer_reduction.cpp | 5 ++-- tests/cpp/test_gpu_utils.cpp | 2 +- tests/cpp/test_id_model.cpp | 30 +++++++++---------- tests/cpp/test_indexing.cpp | 16 +++++----- tests/cpp/test_matmul.cpp | 6 ++-- tests/cpp/test_scatter_gather.cpp | 2 +- tests/cpp/utils.h | 2 +- 35 files changed, 110 insertions(+), 113 deletions(-) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 0176ec7563a..3b159e696c5 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -292,7 +292,7 @@ std::optional> detectMappablePair( // matter in practice. std::optional> findFirstSelfMapping(Fusion* fusion, const IterDomainGraph& id_graph) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { // For each tensor, make sure root, logical and loop domains // should not include domains that are mapped with another domain // in the same set of domains. This may be overly conservative, @@ -342,7 +342,7 @@ void IterDomainGraph::build(Fusion* fusion) { FusionGuard fg(fusion); // Initialize a node for every iteration domain - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { const auto& domain = tv->getLoopDomain(); auto all_ids = tv->domain()->allIDs(); @@ -586,7 +586,7 @@ void IterDomainGraph::build(Fusion* fusion) { // transformations makes it easy to check if different view operations are // consistent with eachother. - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); std::vector all_consumer_tvs; std::copy_if( all_tvs.begin(), diff --git a/csrc/device_lower/analysis/divisible_split.cpp b/csrc/device_lower/analysis/divisible_split.cpp index cbf251e5d35..a4844d8c388 100644 --- a/csrc/device_lower/analysis/divisible_split.cpp +++ b/csrc/device_lower/analysis/divisible_split.cpp @@ -25,7 +25,7 @@ std::unordered_set getAllDivisibleSplits( const ComputeAtMap* ca_map) { std::unordered_set all_divisible_splits; - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); // Find all tensor views with a view like rfactor. Splits used in view // transformations must be divisible by definition. for (auto tv : all_tvs) { diff --git a/csrc/device_lower/analysis/thread_predicate.cpp b/csrc/device_lower/analysis/thread_predicate.cpp index 408140590b1..0c9b4413c9b 100644 --- a/csrc/device_lower/analysis/thread_predicate.cpp +++ b/csrc/device_lower/analysis/thread_predicate.cpp @@ -734,7 +734,7 @@ void ThreadPredicateMap::build(Fusion* fusion) { updateBitSet(expr); } - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (tv->getMemoryType() == MemoryType::Global) { avoidConcretizedBroadcastRedundantWrite(tv); } diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 03d5834c0e8..3aa025b4f8e 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -345,7 +345,7 @@ bool requiresIdModel(Fusion* fusion) { } // If a tensor does not have a nice root->logical/allocation->loop // linear transformation history, use IdModel. - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (!lower_utils::hasRootToLoopLinearTransformations(tv)) { return true; } diff --git a/csrc/device_lower/pass/expr_sort.cpp b/csrc/device_lower/pass/expr_sort.cpp index bb15739af23..b7144d3577d 100644 --- a/csrc/device_lower/pass/expr_sort.cpp +++ b/csrc/device_lower/pass/expr_sort.cpp @@ -1142,7 +1142,7 @@ void ExprSegmentationSorter::initializeForLoopDependencies() { concrete_id_dependencies_.empty(), "For loop dependencies have already been initialized."); - for (auto tv : ir_utils::allTvs(fusion_)) { + for (auto tv : fusion_->allTvs()) { std::unordered_set dependencies; for (int64_t tv_id_i = std::max( tv->getMaxProducerPosition(), diff --git a/csrc/device_lower/pass/loops.cpp b/csrc/device_lower/pass/loops.cpp index bd8c1a60271..30e30e063bb 100644 --- a/csrc/device_lower/pass/loops.cpp +++ b/csrc/device_lower/pass/loops.cpp @@ -145,7 +145,7 @@ void LoopNestGenerator::generate(const std::vector& exprs) { std::unordered_map> concrete_id_dependencies; - for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) { + for (auto tv : FusionGuard::getCurFusion()->allTvs()) { std::unordered_set dependencies; for (auto tv_id : tv->getLoopDomain()) { @@ -212,7 +212,7 @@ void LoopNestGenerator::generate(const std::vector& exprs) { } // Generate loop structure for each tensor view - for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) { + for (auto tv : FusionGuard::getCurFusion()->allTvs()) { // Zero dim tensor support if (tv->nDims() == 0) { loop_structures_[tv] = std::vector(); diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 5091e655556..2b465405406 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -189,7 +189,7 @@ void validateIr(Fusion* fusion) { "Tensor with dynamic transform must be concretized before lowering: ", toDelimitedString(dynamic_tvs.begin(), dynamic_tvs.end())); - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); validateCpAsyncBulk(all_tvs); } @@ -912,7 +912,7 @@ void validateSwizzle(Fusion* fusion) { } void validateAndConvertIterDomainGrouping(Fusion* fusion) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { bool is_grouped = false; for (const auto id_idx : c10::irange(tv->nDims())) { const auto id = tv->axis(id_idx); diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index c7a45c79359..56d9c868f90 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -98,7 +98,7 @@ void collectBufferSizes( std::vector collectRuntimeUsedValues(Fusion* fusion) { std::vector ret; - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); // Collect extent and inputs for (auto tv : all_tvs) { for (auto id : tv->getLoopDomain()) { diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index bd864c9f881..222a3b1afb6 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -871,10 +871,33 @@ bool isExpressionEvaluated(Fusion* fusion) { }); } +namespace { +std::vector findAllTvs(Fusion* fusion) { + auto used_vals = fusion->usedMathVals(); + auto used_tvs = ir_utils::filterByType(used_vals); + + // This shouldn't be necessary but FusionSegmentIoAlias_CUDA due to aliasing + // is having an input disconnected from outputs, and these iter domains are + // being checked in compute at maps in scheduling logic. This shouldn't hurt + // AFAICT. + auto tv_inputs = ir_utils::filterByType(fusion->inputs()); + + std::vector all_tvs({used_tvs.begin(), used_tvs.end()}); + // Sometimes inputs are not connected to outputs, however, we still include + // them when returning allTvs because they are registered as an input. + all_tvs.insert(all_tvs.end(), tv_inputs.begin(), tv_inputs.end()); + + VectorOfUniqueEntries unique_vector( + all_tvs.begin(), all_tvs.end()); + + // all_tvs has duplicates, to deduplicate it and return + return unique_vector.vector(); +} +} // namespace + std::vector Fusion::allTvs() { if (all_tvs_ptr_ == nullptr) { - all_tvs_ptr_ = - std::make_unique>(ir_utils::allTvs(this)); + all_tvs_ptr_ = std::make_unique>(findAllTvs(this)); } return std::vector(*all_tvs_ptr_); } diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index ec10307bbe3..8ef0319b12c 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -2370,7 +2370,7 @@ class FusionSegmentGuard : public NonCopyable { NVF_ERROR(fusion_ != nullptr); #ifndef NDEBUG num_original_exprs_ = fusion_->exprs().size(); - original_tvs_ = ir_utils::allTvs(fusion_); + original_tvs_ = fusion_->allTvs(); #endif // NDEBUG narrowToNewSegment(inputs, outputs); } @@ -2382,7 +2382,7 @@ class FusionSegmentGuard : public NonCopyable { FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard"); #ifndef NDEBUG num_original_exprs_ = fusion_->exprs().size(); - original_tvs_ = ir_utils::allTvs(fusion_); + original_tvs_ = fusion_->allTvs(); #endif // NDEBUG lowered_edges_ = segmented_fusion_->castInputOutputToLowerPrecision( segmented_fusion_->edges()); @@ -2398,7 +2398,7 @@ class FusionSegmentGuard : public NonCopyable { FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard"); #ifndef NDEBUG num_original_exprs_ = fusion_->exprs().size(); - original_tvs_ = ir_utils::allTvs(fusion_); + original_tvs_ = fusion_->allTvs(); #endif // NDEBUG // Cast inputs and outputs of a merged group consisting of a and @@ -2427,7 +2427,7 @@ class FusionSegmentGuard : public NonCopyable { FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard"); #ifndef NDEBUG num_original_exprs_ = fusion_->exprs().size(); - original_tvs_ = ir_utils::allTvs(fusion_); + original_tvs_ = fusion_->allTvs(); #endif // NDEBUG // Cast inputs and outputs of a merged group consisting of @@ -2468,7 +2468,7 @@ class FusionSegmentGuard : public NonCopyable { num_original_exprs_, ", actual: ", num_current_exprs); - auto current_tvs = ir_utils::allTvs(fusion_); + auto current_tvs = fusion_->allTvs(); NVF_ERROR( original_tvs_ == current_tvs, "Failed to revert temporary changes."); #endif diff --git a/csrc/id_model/validation_utils.cpp b/csrc/id_model/validation_utils.cpp index c8c116df8e8..6dd6e520f7c 100644 --- a/csrc/id_model/validation_utils.cpp +++ b/csrc/id_model/validation_utils.cpp @@ -118,7 +118,7 @@ bool exprsMap( IdModelValidator::IdModelValidator(Fusion* fusion, bool allow_self_mapping) : ca_map_(fusion, allow_self_mapping) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { for (auto id : tv->domain()->allIDs()) { if (id->definition() && id->definition()->isA()) { has_swizzle_ = true; diff --git a/csrc/inlining.cpp b/csrc/inlining.cpp index e308183cc10..d71fc059846 100644 --- a/csrc/inlining.cpp +++ b/csrc/inlining.cpp @@ -29,7 +29,7 @@ void MaxPosCalculator::buildUnmappableDims(bool compute_at_only) { } ComputeAtLogicalDomainMap logical_map; logical_map.build(); - auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); + auto all_tvs = FusionGuard::getCurFusion()->allTvs(); for (auto tv : all_tvs) { auto consumers = ir_utils::consumerTvsOf(tv); for (auto consumer : consumers) { @@ -173,7 +173,7 @@ size_t MaxPosCalculator::getMaxPosAll( } void inlineMost(const std::unordered_set& uninlinable_ids) { - inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids); + inlineMost(FusionGuard::getCurFusion()->allTvs(), uninlinable_ids); } void inlineMost( diff --git a/csrc/ir/graphviz.cpp b/csrc/ir/graphviz.cpp index 4e7413eb148..7cbd23f7dd3 100644 --- a/csrc/ir/graphviz.cpp +++ b/csrc/ir/graphviz.cpp @@ -426,7 +426,7 @@ void TransformToDot::handle(Fusion* fusion) { // Make sure the loop domains are ordered correctly indent() << "graph [ordering=\"out\"];\n"; - for (const auto tv : ir_utils::allTvs(fusion)) { + for (const auto tv : fusion->allTvs()) { handle(tv); } diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 5d52a898e84..ebdcf699f33 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -403,25 +403,6 @@ std::vector outputTvsOf(std::vector tvs) { return uniqueEntries(out_tvs); } -std::vector allTvs(Fusion* fusion) { - auto used_vals = fusion->usedMathVals(); - auto used_tvs = ir_utils::filterByType(used_vals); - - // This shouldn't be necessary but FusionSegmentIoAlias_CUDA due to aliasing - // is having an input disconnected from outputs, and these iter domains are - // being checked in compute at maps in scheduling logic. This shouldn't hurt - // AFAICT. - auto tv_inputs = ir_utils::filterByType(fusion->inputs()); - - std::vector all_tvs({used_tvs.begin(), used_tvs.end()}); - // Sometimes inputs are not connected to outputs, however, we still include - // them when returning allTvs because they are registered as an input. - all_tvs.insert(all_tvs.end(), tv_inputs.begin(), tv_inputs.end()); - - // all_tvs has duplicates, to deduplicate it and return - return uniqueEntries(all_tvs); -} - VectorOfUniqueEntries allTvsOfExprs( const std::vector& exprs) { VectorOfUniqueEntries all_tvs; @@ -438,7 +419,7 @@ VectorOfUniqueEntries allTvsOfExprs( std::vector allTvsExcept( Fusion* fusion, const std::unordered_set& except) { - auto all_tvs = allTvs(fusion); + auto all_tvs = fusion->allTvs(); std::vector result; for (auto tv : all_tvs) { if (except.count(tv) == 0) { @@ -803,7 +784,7 @@ bool hasResizedRfactor(const TensorView* tv) { } std::vector getTVsWithDynamicTransform(Fusion* fusion) { - const auto all_tvs = ir_utils::allTvs(fusion); + const auto all_tvs = fusion->allTvs(); std::vector dynamic_tvs; std::copy_if( all_tvs.begin(), diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 913df3773f4..46225feb240 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -383,9 +383,6 @@ std::vector inputTvsOf(std::vector tvs); // Returns consumers of tvs that are outputs of fusion std::vector outputTvsOf(std::vector tvs); -// returns all tensor views in fusion that are used between outputs and inputs. -NVF_API std::vector allTvs(Fusion* fusion); - // returns all tensor views used in the provided expressions VectorOfUniqueEntries allTvsOfExprs( const std::vector& exprs); diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 3c83ed5642a..90d498103e2 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -1024,8 +1024,7 @@ FusionKernelRuntime::FusionKernelRuntime( // SchedulerRuntimeInfo modifies the fusion, so it is required for both // compile paths. - std::vector all_tvs = - fusion->allTvs(); // ir_utils::allTvs(fusion.get()); + std::vector all_tvs = fusion->allTvs(); SchedulerRuntimeInfo runtime_info( fusion.get(), args, nullptr, all_tvs, forced_index_type); @@ -1491,7 +1490,6 @@ std::optional FusionKernelRuntime:: // Get all tensorviews for segmented fusion std::vector all_tvs_for_fusion_to_run = fusion_to_run->allTvs(); - // ir_utils::allTvs(fusion_to_run); SchedulerRuntimeInfo fusion_to_run_info( fusion_to_run, diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index b5b8c7f1725..9b32ba9c690 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -233,7 +233,7 @@ void shardAllLike(TensorView* ref, std::vector tvs) { int64_t requestedNumberOfDevices(Fusion* fusion) { DeviceIdxType max_index = 0; - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (tv->hasDeviceMesh()) { for (auto d_id : tv->getDeviceMesh().vector()) { max_index = std::max(max_index, d_id); @@ -253,7 +253,7 @@ void unshard(TensorView* tv) { } void unshard(Fusion* fusion) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { unshard(tv); } } diff --git a/csrc/preseg_passes/mark_aliases_prepare.cpp b/csrc/preseg_passes/mark_aliases_prepare.cpp index 2478105e33e..6afdbca299c 100644 --- a/csrc/preseg_passes/mark_aliases_prepare.cpp +++ b/csrc/preseg_passes/mark_aliases_prepare.cpp @@ -56,7 +56,7 @@ std::unordered_set exprsDependedByNonAliases( const AliasAnalysisResult& analysis, Fusion* fusion) { std::vector non_aliases; - for (TensorView* tv : ir_utils::allTvs(fusion)) { + for (TensorView* tv : fusion->allTvs()) { if (analysis.getRoot(tv) == nullptr) { non_aliases.push_back(tv); } @@ -129,7 +129,7 @@ void MarkAliasesPreparePass::runPass(Fusion* fusion) { } // Materialize the alias-enabling allocation domain. - for (TensorView* tv : ir_utils::allTvs(fusion)) { + for (TensorView* tv : fusion->allTvs()) { if (analysis.getRoot(tv) == nullptr) { continue; } diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index e3e4f39f8d9..997566cfaf1 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -42,7 +42,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Validate that meshes are assigned to all TensorViews or none. TensorView* tv_with_mesh = nullptr; TensorView* tv_without_mesh = nullptr; - for (TensorView* tv : ir_utils::allTvs(fusion)) { + for (TensorView* tv : fusion->allTvs()) { auto update_if_null = [](TensorView*& lhs, TensorView* rhs) { if (lhs == nullptr) { lhs = rhs; diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 5f71d7a4604..d6926065d67 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -230,7 +230,7 @@ void FusionDefinition::setupSchedule(const at::ArrayRef& inputs) { user_schedule_fusion, args, /*precomuted_values=*/nullptr, - ir_utils::allTvs(user_schedule_fusion)); + user_schedule_fusion->allTvs()); // Manually setting the fusion guard as there is not a good way of using a // guard in a local scope across the schedule function @@ -243,7 +243,7 @@ void FusionDefinition::finalizeSchedule( FUSER_PERF_SCOPE("FusionDefinition::finalizeSchedule"); // TODO: remove when multidevice executor integration is done natively Fusion* fusion = user_sched_->schedule.get(); - std::vector tvs = ir_utils::allTvs(fusion); + std::vector tvs = fusion->allTvs(); if (std::any_of(tvs.begin(), tvs.end(), [](Val* v) { return v->isA() && v->as()->hasDeviceMesh(); })) { diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index a29cb3af26a..51562346e72 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -238,7 +238,7 @@ std::vector getOuterBroadcastTvs( // find the broadcast tensor whose broadcast mask is same to the reference std::vector outer_broadcast_tvs; - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (std::any_of( tv->getLoopDomain().begin(), tv->getLoopDomain().end(), diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index df6ba64499e..2445ab79afa 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -901,7 +901,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // unrolling manually. inlineAllAt(reference_tv, unswitch_pos, true); - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); // Inline at the inner most position. The CA position of all tensors except // inputs, cached inputs and outputs will be updated. diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 9800576758d..047d9e479b1 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -39,7 +39,7 @@ SchedulerRuntimeInfo::SchedulerRuntimeInfo( } else { index_type_ = registry_utils::getIndexTypeOfKernel( complete_fusion_, - all_tvs.empty() ? ir_utils::allTvs(complete_fusion_) : all_tvs, + all_tvs.empty() ? complete_fusion_->allTvs() : all_tvs, args, *expression_evaluator_); } diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index b12475b208e..2af092c08f2 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -62,7 +62,7 @@ bool checkPatternEquivalence( bool hasNonUniqueBcast(Fusion* fusion) { ConcretizedBroadcastDomains concretize_info(fusion); - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { for (auto id : tv->getMaybeRootDomain()) { if (concretize_info.maybeNonUniquelyConcretized(id)) { return true; diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index db40a638068..6013f2838e6 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -311,7 +311,7 @@ void parallelizeAllLike( } if (selected_tvs.empty()) { - selected_tvs = ir_utils::allTvs(reference_tv->fusion()); + selected_tvs = reference_tv->fusion()->allTvs(); } for (auto tv : selected_tvs) { if (tv->isFusionInput()) { @@ -564,7 +564,7 @@ PersistentBufferInfo persistentBuffers(Fusion* fusion) { ComputeAtLogicalDomainMap logical_map; logical_map.build(); - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); for (auto producer : all_tvs) { // Are all producer ids mappable to all consumers @@ -1063,7 +1063,7 @@ std::pair canonicalDimReduction( } std::vector getReductionTvs(Fusion* fusion) { - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); std::vector reduction_tvs; for (auto tv : all_tvs) { if (!tv->isFusionInput() && @@ -1130,7 +1130,7 @@ std::vector getTVsWithNonReductionRFactor(Fusion* fusion) { // Reset inputs and outputs to global memory, everything else to local. void clearMemorySpace(Fusion* fusion) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (tv->isFusionInput() || tv->isFusionOutput()) { tv->setMemoryType(MemoryType::Global); } else { @@ -1986,7 +1986,7 @@ DisjointSets disjointLogicalSets(Fusion* fusion) { // If iter domains are involved in any transformation from root domains to // logical domains they should be considered "contaminated". - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { for (auto expr : StmtSort::getExprsTo( {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()})) { if (expr->isA()) { @@ -2146,7 +2146,7 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { // If iter domains are involved in any transformation from root domains to // logical domains they should be considered "contaminated". - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { for (auto expr : StmtSort::getExprsBetween( {tv->getMaybeRootDomain().begin(), tv->getMaybeRootDomain().end()}, {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()})) { @@ -2183,7 +2183,7 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { // If iter domains are involved in any transformation from root domains to // logical domains they should be considered "contaminated". - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (!tv->hasRoot()) { continue; } @@ -2249,7 +2249,7 @@ std::vector> getNonPointwiseProducerConsumerPairs(Fusion* fusion) { std::vector> tvs; - for (auto consumer : ir_utils::allTvs(fusion)) { + for (auto consumer : fusion->allTvs()) { if (consumer->isFusionInput()) { continue; } @@ -2570,7 +2570,7 @@ void moveNonConcretizedBroadcastInnermost( } } - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { std::vector broadcast_to_move; for (const auto i : c10::irange(tv->getLoopDomain().size())) { auto loop_id = tv->getLoopDomain().at(i); diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 4cc1ac113f4..eecc11cf03f 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -772,7 +772,7 @@ TEST_F(NVFuserTest, FusionIssue1430_CUDA) { scheduler_utils::parallelizeAllLike(rfactor); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv != tv1 || tv != tv3) { for (auto i : c10::irange(tv->nDims())) { if (isParallelTypeVectorize(tv->axis(i)->getParallelType())) { @@ -2054,7 +2054,7 @@ TEST_F(NVFuserTest, FusionExactLogicalDomainMap_CUDA) { exact_map.toString()); // They must not be mapped with anything else. - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { for (auto logical_id : tv->getLogicalDomain()) { if (logical_id == tv2_bc || logical_id == tv3_bc) { continue; @@ -2167,7 +2167,7 @@ TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) { cached_input->computeAt(rfactor_tv, 4, ComputeAtMode::BestEffort); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv == cached_input || tv == tv_avg || tv == tv_M2) { continue; } @@ -8535,7 +8535,7 @@ TEST_F(NVFuserTest, MoveNonConcretizedBroadcastInNormalization) { auto ref_outermost = tv7->getLoopDomain().at(0); IdModel id_model(&fusion); const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } @@ -8603,7 +8603,7 @@ TEST_F(NVFuserTest, MoveNonConcretizedBroadcastInPointwise) { auto ref_outermost = tv5->getLoopDomain().at(0); IdModel id_model(&fusion); const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } @@ -8670,7 +8670,7 @@ TEST_F(NVFuserTest, MoveNonConcretizedBroadcastInReduction) { auto ref_outermost = tv6->getLoopDomain().at(0); IdModel id_model(&fusion); const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } diff --git a/tests/cpp/test_gpu_compute_with.cpp b/tests/cpp/test_gpu_compute_with.cpp index 7abf3e891f8..b2d872308bb 100644 --- a/tests/cpp/test_gpu_compute_with.cpp +++ b/tests/cpp/test_gpu_compute_with.cpp @@ -130,7 +130,7 @@ TEST_F(NVFuserTest, FusionComputeWith1_CUDA) { // Set the global inlining only with the outer axis std::unordered_set uninlinable; - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->nDims() == 2) { uninlinable.insert(tv->axis(1)); } @@ -424,7 +424,7 @@ TEST_F(NVFuserTest, FusionComputeWith6_CUDA) { TransformPropagator propagator(tv3_rf); MaxLogicalDomainInfoSpanningTree(tv3_rf).traverse(&propagator); - scheduler_utils::parallelizeAllLike(tv3_rf, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv3_rf, fusion.allTvs()); tv1->axis(-1)->parallelize(ParallelType::Vectorize); tv7->axis(-1)->parallelize(ParallelType::Vectorize); diff --git a/tests/cpp/test_gpu_fused_reduction.cpp b/tests/cpp/test_gpu_fused_reduction.cpp index bf0bb8ec877..e67875f4a1a 100644 --- a/tests/cpp/test_gpu_fused_reduction.cpp +++ b/tests/cpp/test_gpu_fused_reduction.cpp @@ -2085,7 +2085,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce4_CUDA) { tv4->axis(0)->parallelize(ParallelType::BIDx); tv4->axis(1)->parallelize(ParallelType::TIDx); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->axis(-2)->parallelize(ParallelType::BIDy); tv->axis(-1)->parallelize(ParallelType::TIDy); } @@ -2355,8 +2355,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelfordShmoo_CUDA) { })); transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Serial); - scheduler_utils::parallelizeAllLike( - transform_ref_rf, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(transform_ref_rf, fusion.allTvs()); ParallelType vec_pt = ParallelType::Vectorize; tv1->axis(vec_id)->parallelize(vec_pt); diff --git a/tests/cpp/test_gpu_outer_reduction.cpp b/tests/cpp/test_gpu_outer_reduction.cpp index f6c120c5aba..afbe2eb5d5d 100644 --- a/tests/cpp/test_gpu_outer_reduction.cpp +++ b/tests/cpp/test_gpu_outer_reduction.cpp @@ -101,7 +101,7 @@ TEST_F(OuterReductionTest, GroupedGridWelfordOuterOpt) { ref_rf->axis(3)->parallelize(ParallelType::BIDy); ref_rf->axis(5)->parallelize(ParallelType::TIDy); - scheduler_utils::parallelizeAllLike(ref_rf, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(ref_rf, fusion.allTvs()); tv1->axis(-1)->parallelize(ParallelType::Vectorize); tv3->axis(-1)->parallelize(ParallelType::Group); @@ -552,8 +552,7 @@ void scheduleNormalization(Fusion& fusion, const OuterReductionParams& params) { unswitch_id->parallelize(ParallelType::Serial); } - scheduler_utils::parallelizeAllLike( - reduction_tv_rf, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(reduction_tv_rf, fusion.allTvs()); // Vectorize inputs for (auto input_cache : input_caches) { diff --git a/tests/cpp/test_gpu_utils.cpp b/tests/cpp/test_gpu_utils.cpp index 908272aa1e2..f7bc304e47a 100644 --- a/tests/cpp/test_gpu_utils.cpp +++ b/tests/cpp/test_gpu_utils.cpp @@ -1058,7 +1058,7 @@ TEST_F(VectorizeHelperTest, SpanningTree_CUDA) { auto mapper = vectorize_helper::ContiguousInnerDimensionsMapper::map( out, {out->axis(0), out->axis(1)}); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->name() == 0 || tv->name() == 1) { continue; } diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 6f6f1d72c6f..899fc657a88 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -261,7 +261,7 @@ void checkStep2Results(Fusion* fusion, const IdModelTester& tester) { } }; - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { // If there's no broadcast or it isn't inlined, there's no // promotion if (std::none_of( @@ -591,7 +591,7 @@ TEST_F(IdModelTest, ValGraphStmtSort2) { // Note that the two groups of tensors, {tv0, tv1} and {tv2, tv3}, // are not connected - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->merge(0)->split(0, 4); } @@ -674,7 +674,7 @@ TEST_F(IdModelTest, ValGraphStmtSort3) { TEST_F(IdModelTest, ValGraphStmtSort4) { auto fusion = createFusionWithMultipleResolutionPaths(); FusionGuard fg(fusion.get()); - auto all_tvs = ir_utils::allTvs(fusion.get()); + auto all_tvs = fusion->allTvs(); // Since this fusion is not supported by ComputeAtMap, the // validation flag must be false @@ -953,14 +953,14 @@ TEST_F(IdModelTest, LoopPromotion4) { TransformPropagator propagator(tv4); MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->inlineAt(-2); } IdModelTester tester(&fusion); // Verify all tensors with root broadcast have correct resolutions - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { // Skip tensors with no broadcast or non-inlined if (std::none_of( tv->getLogicalDomain().begin(), @@ -1078,7 +1078,7 @@ TEST_F(IdModelTest, LoopPromotion5) { tv2->axis(1)->parallelize(ParallelType::Unroll); tv2->axis(2)->parallelize(ParallelType::TIDx); - auto all_tvs = ir_utils::allTvs(&fusion); + auto all_tvs = fusion.allTvs(); IdModelTester tester(&fusion); @@ -1225,7 +1225,7 @@ TEST_F(IdModelTest, LoopPromotion5) { TEST_F(IdModelTest, LoopPromotion6) { auto fusion = createFusionWithMultipleResolutionPaths(); FusionGuard fg(fusion.get()); - auto all_tvs = ir_utils::allTvs(fusion.get()); + auto all_tvs = fusion->allTvs(); IdModelTester tester(fusion.get()); @@ -1558,7 +1558,7 @@ TEST_F(IdModelTest, LoopPromotion7) { tv2->split(-1, 8); - auto all_tvs = ir_utils::allTvs(&fusion); + auto all_tvs = fusion.allTvs(); IdModelTester tester(&fusion); @@ -1698,7 +1698,7 @@ TEST_F(IdModelTest, LoopPromotion8) { // [2, 4, (3*5//2)*7//4] tv5->inlineAt(2); - auto all_tvs = ir_utils::allTvs(&fusion); + auto all_tvs = fusion.allTvs(); IdModelTester tester(&fusion); @@ -1992,7 +1992,7 @@ TEST_F(IdModelTest, LoopPromotionTwoStepFailureReproSimple) { TransformPropagatorWithCheck propagator(t4); MaxLogicalDomainInfoSpanningTree(t4).traverse(&propagator); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->inlineAt(1); } @@ -2044,7 +2044,7 @@ TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) { fusion.addOutput(tv11); // Merge all domains except for tv10 and tv11 - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv == tv10 || tv == tv11) { continue; } @@ -2054,7 +2054,7 @@ TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) { } // Fully inline all tensors up until tv10 - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv == tv9 || tv == tv10 || tv == tv11) { continue; } @@ -2446,7 +2446,7 @@ TEST_F(IdModelTest, LoopPromotionWithViewRFactor1) { // All of the inlined tensors (i.e., all tensors except for the // inputs) should be grouped together. - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } @@ -2496,7 +2496,7 @@ TEST_F(IdModelTest, LoopPromotionWithLogicalDomains2) { // All of the inlined tensors (i.e., all tensors except for the // inputs) should be grouped together. - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } @@ -2560,7 +2560,7 @@ TEST_F(IdModelTest, LoopPromotionCoverage) { // All tvs except for inptus should be just a 1D tensor and be // promoted to a domain that is exactly mappd with the loop domain // of tv10. - for (const auto tv : ir_utils::allTvs(&fusion)) { + for (const auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index b1eb4ea80e1..9f9d78e26b0 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -561,7 +561,7 @@ TEST_F(IndexingTest, SimplePointwise2) { tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv3, fusion.allTvs()); // Test shared memory indexing tv2->setMemoryType(MemoryType::Shared); @@ -1044,7 +1044,7 @@ TEST_F(IndexingTest, SimpleBroadcast4) { TransformPropagator propagator(tv4); MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->inlineAt(-2); } @@ -1344,7 +1344,7 @@ TEST_F(IndexingTest, SimpleVectorize) { inlineMost(); - scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv2, fusion.allTvs()); struct GetReference : AbstractGetReference { GetReference(const TensorIndexer& indexer, const IdModel& id_model) @@ -1413,7 +1413,7 @@ TEST_F(IndexingTest, NonInnermostVectorize) { tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv3, fusion.allTvs()); tv1->axis(2)->parallelize(ParallelType::Vectorize); tv3->axis(2)->parallelize(ParallelType::Vectorize); @@ -1648,7 +1648,7 @@ TEST_F(IndexingTest, InlinedUnroll) { tv4->axis(1)->parallelize(ParallelType::Unroll); - scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv4, fusion.allTvs()); // The CA position of tv2 is 1 as shown below: // @@ -1704,7 +1704,7 @@ TEST_F(IndexingTest, SmemAllocationDomainForTranspose) { } // [I0, I1] -> [(I0/32 * I1/32), (32 * 32) / 4, 4] - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->split(0, 32); tv->split(2, 32); tv->reorder({{1, 2}}); @@ -2834,7 +2834,7 @@ TEST_F(PredicateIndexingTest, SimpleVectorize) { inlineMost(); - scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv2, fusion.allTvs()); // T1_l[ iblockIdx.x9{( ceilDiv(( ceilDiv(i0, 4) ), 128) )}, // ithreadIdx.x10{128}, iV8{4} ] ca_pos( 2 ) T2_g[ iblockIdx.x5{( ceilDiv(( @@ -2904,7 +2904,7 @@ TEST_F(PredicateIndexingTest, NonInnermostVectorize) { tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv3, fusion.allTvs()); tv1->axis(2)->parallelize(ParallelType::Vectorize); tv3->axis(2)->parallelize(ParallelType::Vectorize); diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 862c1f99fc7..c6da1d7d849 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -2406,7 +2406,7 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogue) { // for prologue and 1 for epilogue. int num_shared_mem_tensors = 0; int expected_num_shared_mem_tensors = params.use_smem_epilogue ? 3 : 2; - for (const auto& tv : ir_utils::allTvs(&fusion)) { + for (const auto& tv : fusion.allTvs()) { if (tv->getMemoryType() == MemoryType::Shared) { num_shared_mem_tensors++; } @@ -2640,7 +2640,7 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueCast) { // for prologue and 1 for epilogue. int num_shared_mem_tensors = 0; int expected_num_shared_mem_tensors = params.use_smem_epilogue ? 3 : 2; - for (const auto& tv : ir_utils::allTvs(&fusion)) { + for (const auto& tv : fusion.allTvs()) { if (tv->getMemoryType() == MemoryType::Shared) { num_shared_mem_tensors++; } @@ -2733,7 +2733,7 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueRelu) { // for prologue and 1 for epilogue. int num_shared_mem_tensors = 0; int expected_num_shared_mem_tensors = params.use_smem_epilogue ? 3 : 2; - for (const auto& tv : ir_utils::allTvs(&fusion)) { + for (const auto& tv : fusion.allTvs()) { if (tv->getMemoryType() == MemoryType::Shared) { num_shared_mem_tensors++; } diff --git a/tests/cpp/test_scatter_gather.cpp b/tests/cpp/test_scatter_gather.cpp index c8a39e88b01..fbac505ff0d 100644 --- a/tests/cpp/test_scatter_gather.cpp +++ b/tests/cpp/test_scatter_gather.cpp @@ -561,7 +561,7 @@ TEST_F(ScatterGatherTest, TakeAlongAxisIntermediateTensorPointwise1) { tv4->axis(-1)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv4); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } diff --git a/tests/cpp/utils.h b/tests/cpp/utils.h index a0ee6764d06..4ae88f4064c 100644 --- a/tests/cpp/utils.h +++ b/tests/cpp/utils.h @@ -112,7 +112,7 @@ inline void clearL2Cache() { }; inline TensorView* loweredTv(TensorView* tv, kir::Kernel* kernel) { - auto used_tvs = ir_utils::allTvs(kernel); + auto used_tvs = kernel->allTvs(); TensorView* matching_tv = nullptr; for (auto lowered_tv : used_tvs) { if (lowered_tv->name() == tv->name()) { From cd0c30b5c437db4c61aab9ae07cb75f7983de1e2 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 3 Sep 2024 10:15:42 -0400 Subject: [PATCH 41/54] Cache the vectorization break point in schedulers (#2887) Cache the vectorization break point as it's a compile time value. --- csrc/scheduler/compile_time_info.h | 8 ++ csrc/scheduler/normalization_inner_outer.cpp | 17 ++-- csrc/scheduler/normalization_utils.cpp | 84 ++++++++++---------- csrc/scheduler/reduction.cpp | 18 +++-- csrc/scheduler/registry.cpp | 2 + csrc/scheduler/utils.cpp | 25 ++++++ csrc/scheduler/utils.h | 6 ++ 7 files changed, 108 insertions(+), 52 deletions(-) diff --git a/csrc/scheduler/compile_time_info.h b/csrc/scheduler/compile_time_info.h index 02b247d2ff7..b5a933c0ca1 100644 --- a/csrc/scheduler/compile_time_info.h +++ b/csrc/scheduler/compile_time_info.h @@ -46,6 +46,7 @@ enum class CompileTimeEntryType { CAN_SCHEDULE_TRANSPOSE, CAN_SCHEDULE_MUL_SUM_AS_MMA, LOGICAL_REORDER_MAP, + VECTORIZATION_BREAK_POINT_OF_RED_PROD }; //! Entry type definition class for `DOMAIN_MAP`, @@ -195,6 +196,13 @@ class LogicalReorderMap { CompileTimeEntryType::LOGICAL_REORDER_MAP; }; +class VectorizationBreakPointOfReductionProducer { + public: + using DataType = int64_t; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::VECTORIZATION_BREAK_POINT_OF_RED_PROD; +}; + //! Base abstract class for unified storage in `HeuristicSummary`, //! each entry in `HeuristicSummary` will be a subclass. class CompileTimeInfoBase : public PolymorphicBase { diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 51562346e72..852426cd908 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -967,12 +967,19 @@ std::shared_ptr getInnerOuterPersistentHeuristics( auto properties = scheduler_utils::getReductionProperties(fusion, runtime_info, ref_red_tv); auto reduced_tv = ir_utils::getSoleProducerTv(ref_red_tv); + + // Although properties contains runtime information + // "inner_most_dimension_ndims" is a compile time value + auto vec_break_point = HeuristicSummaryEntry< + HeuristicCompileTime::VectorizationBreakPointOfReductionProducer>( + data_cache, [&ref_red_tv, &reduced_tv, &properties]() { + return std::make_unique( + vectorize_helper::getVectorizationBreakPointOfReductionProducer( + ref_red_tv, reduced_tv, properties.inner_most_dimension_ndims)); + }); + const auto vectorize_factor = vectorize_helper::getVectorizationFactor( - runtime_info, - reduced_tv, - data_cache, - vectorize_helper::getVectorizationBreakPointOfReductionProducer( - ref_red_tv, reduced_tv, properties.inner_most_dimension_ndims)); + runtime_info, reduced_tv, data_cache, vec_break_point.get()); auto persistent_buffer_info_entry = HeuristicSummaryEntry( diff --git a/csrc/scheduler/normalization_utils.cpp b/csrc/scheduler/normalization_utils.cpp index ea580284f75..a8a50ebdee9 100644 --- a/csrc/scheduler/normalization_utils.cpp +++ b/csrc/scheduler/normalization_utils.cpp @@ -764,7 +764,7 @@ bool isProjectBufferToInputs( const bool can_use_smem_persistent, const bool check_projected_buffer_size) { // don't project if there are view ops and no buffer can be projected - bool can_project = ir_utils::getViewOps(fusion).empty() && + bool can_project = !persistent_buffer_info.has_view_ops && persistent_buffer_size_info.projected_persistent_buffer_size > 0; if (!can_project) { return false; @@ -796,28 +796,13 @@ bool isProjectBufferToInputs( } } - // check ops between persistent buffer and inputs. - // TODO: check more ops - bool has_exp_op = false; - const auto& projectable_buffers = - persistent_buffer_info.projectable_persistent_buffers; - auto all_inputs = ir_utils::inputTvsOf(projectable_buffers); - const auto all_exprs = StmtSort::getExprsBetween( - {all_inputs.begin(), all_inputs.end()}, - {projectable_buffers.begin(), projectable_buffers.end()}); - for (auto expr : all_exprs) { - if (expr->isA() && - expr->as()->getUnaryOpType() == UnaryOpType::Exp) { - has_exp_op = true; - } - // don't project if recompute requires rng op - if (expr->isA()) { - return false; - } + // don't project if recompute requires rng op + if (persistent_buffer_info.projection_with_rng_op) { + return false; } // free to project if no exp op - if (!has_exp_op) { + if (!persistent_buffer_info.projection_with_exp_op) { return true; } @@ -868,27 +853,34 @@ PersistentKernelProperties getPersistentKernelProperties( return std::make_unique>( scheduler_utils::getReductionTvs(fusion)); }); + auto& reduction_tvs = reduction_tv_entry.get(); NVF_ERROR(!reduction_tvs.empty(), "Need reduction tensor views to schedule."); auto ref_red_tv = reduction_tvs[0]; - // (1) fusion checks checkReductionTvForScheduling(fusion, ref_red_tv); - // (2) reduction properties - auto properties = + scheduler_utils::ReductionTvProperties properties; + TensorView* reduced_tv = nullptr; + int64_t vectorize_factor = -1; + + properties = scheduler_utils::getReductionProperties(fusion, runtime_info, ref_red_tv); + reduced_tv = ir_utils::getSoleProducerTv(ref_red_tv); - // (3) vectorization factor - auto reduced_tv = ir_utils::getSoleProducerTv(ref_red_tv); - auto vectorize_factor = vectorize_helper::getVectorizationFactor( - runtime_info, - reduced_tv, - data_cache, - vectorize_helper::getVectorizationBreakPointOfReductionProducer( - ref_red_tv, reduced_tv, properties.inner_most_dimension_ndims)); + // Although properties contains runtime information + // "inner_most_dimension_ndims" is a compile time value + auto vec_break_point = HeuristicSummaryEntry< + HeuristicCompileTime::VectorizationBreakPointOfReductionProducer>( + data_cache, [&ref_red_tv, &reduced_tv, &properties]() { + return std::make_unique( + vectorize_helper::getVectorizationBreakPointOfReductionProducer( + ref_red_tv, reduced_tv, properties.inner_most_dimension_ndims)); + }); + + vectorize_factor = vectorize_helper::getVectorizationFactor( + runtime_info, reduced_tv, data_cache, vec_break_point.get()); - // (4) info about persistent buffer auto persistent_buffer_info_entry = HeuristicSummaryEntry( data_cache, [&fusion]() { @@ -902,7 +894,7 @@ PersistentKernelProperties getPersistentKernelProperties( auto persistent_buffer_size_info = scheduler_utils::persistentBufferSize( fusion, runtime_info, persistent_buffer_info, data_cache); - // (5) can project to input? + // Can project to input? // Figure out if we want to projet persistent buffers to the inputs for // exmaple if we have an input tensor t0 that's fp16: // @@ -921,7 +913,7 @@ PersistentKernelProperties getPersistentKernelProperties( // TODO: Fix projected persistent buffers with view // https://github.com/csarofeen/pytorch/issues/2054 - // (6) Project to input when it can reduce buffer size and the gains of + // Project to input when it can reduce buffer size and the gains of // reducing buffer size is larger than the pains of recalculations. bool can_use_smem_persistent = properties.inner_most_dimension_numel == properties.total_reduction_numel; @@ -932,11 +924,11 @@ PersistentKernelProperties getPersistentKernelProperties( persistent_buffer_size_info, heuristic, can_use_smem_persistent); - auto max_persistent_buffer_size = project_persistent_buffers + int64_t max_persistent_buffer_size = project_persistent_buffers ? persistent_buffer_size_info.projected_persistent_buffer_size : persistent_buffer_size_info.persistent_buffer_size; - // (7) info about input and output tensors + // Info about input and output tensors // Base max dtype and n_tensor_inputs on tensors that are vectorizable (i.e. // share inner dimension with data pattern we're looking at). // TODO: This might be better if it was the larger of input or outputs. Would @@ -949,9 +941,12 @@ PersistentKernelProperties getPersistentKernelProperties( scheduler_utils::getInputsOutputsWithInnerDim( reduced_tv, false, false)); }); - auto& unrollable_inputs_outputs = unrollable_inputs_outputs_entry.get(); + + // Info about ops in the fusion, used to set model specific parameters int64_t max_dtype_size = 1; int64_t n_tensor_inputs = 0; + + auto& unrollable_inputs_outputs = unrollable_inputs_outputs_entry.get(); for (auto tv : unrollable_inputs_outputs) { if (!tv->isFusionInput()) { continue; @@ -965,18 +960,23 @@ PersistentKernelProperties getPersistentKernelProperties( // zero. n_tensor_inputs = std::max(n_tensor_inputs, (int64_t)1); - // Info about ops in the fusion, used to set model specific parameters // Exp op typically used in softmax is expensive and needs more registers. bool has_exp_op = false; - for (auto expr : fusion->exprs()) { - if (expr->isA() && - expr->as()->getUnaryOpType() == UnaryOpType::Exp) { + + // Could save fusion->exprs() instead of doing this, but allTvs is already + // cached in fusion so using that for now. + for (auto tv : fusion->allTvs()) { + if (tv->definition() == nullptr) { + continue; + } + if (tv->definition()->isA() && + tv->definition()->as()->getUnaryOpType() == UnaryOpType::Exp) { has_exp_op = true; break; } } - // (9) return collected properties to get heuristics. + // Return collected properties to get heuristics. return PersistentKernelProperties{ .inner_most_dimension_numel = properties.inner_most_dimension_numel, .total_reduction_numel = properties.total_reduction_numel, diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index db09bea4412..f6feb34ba52 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -1301,12 +1301,20 @@ std::shared_ptr getReductionHeuristics( auto& unrollable_inputs_outputs = unrollable_inputs_outputs_entry.get(); + // Although properties contains runtime information + // "inner_most_dimension_ndims" is a compile time value + auto vec_break_point = HeuristicSummaryEntry< + HeuristicCompileTime::VectorizationBreakPointOfReductionProducer>( + data_cache, [&reduction_tv, &reduced_tv, &properties]() { + return std::make_unique( + vectorize_helper::getVectorizationBreakPointOfReductionProducer( + reduction_tv, + reduced_tv, + properties.inner_most_dimension_ndims)); + }); + const auto vectorize_factor = vectorize_helper::getVectorizationFactor( - runtime_info, - reduced_tv, - data_cache, - vectorize_helper::getVectorizationBreakPointOfReductionProducer( - reduction_tv, reduced_tv, properties.inner_most_dimension_ndims)); + runtime_info, reduced_tv, data_cache, vec_break_point.get()); // Base max dtype and n_tensor_inputs on tensors that are vectorizable (i.e. // share inner dimension with data pattern we're looking at). diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 047d9e479b1..7bddf484bb0 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -539,5 +539,7 @@ template class HeuristicSummaryEntry; template class HeuristicSummaryEntry< HeuristicCompileTime::CanScheduleTranspose>; template class HeuristicSummaryEntry; +template class HeuristicSummaryEntry< + HeuristicCompileTime::VectorizationBreakPointOfReductionProducer>; } // namespace nvfuser diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 6013f2838e6..d508f257fb5 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -619,6 +619,9 @@ PersistentBufferInfo persistentBuffers(Fusion* fusion) { PersistentBufferResolution::getResolutionPointsOf(fusion, buffer)); } + // don't project if there are view ops and no buffer can be projected + persistent_buffer_info.has_view_ops = !ir_utils::getViewOps(fusion).empty(); + // Find projectable persistent buffers auto reduction_tvs = getReductionTvs(fusion); for (auto persistent_buffer : persistent_buffer_info.persistent_buffers) { @@ -636,6 +639,11 @@ PersistentBufferInfo persistentBuffers(Fusion* fusion) { } } + // Projection analysis below + if (persistent_buffer_info.projectable_persistent_buffers.empty()) { + return persistent_buffer_info; + } + // Get a list of inputs of the projectable buffers auto all_inputs = ir_utils::inputTvsOf( persistent_buffer_info.projectable_persistent_buffers); @@ -666,6 +674,23 @@ PersistentBufferInfo persistentBuffers(Fusion* fusion) { } } + // check ops between persistent buffer and inputs. + // TODO: check more ops + const auto all_exprs = StmtSort::getExprsBetween( + {all_inputs.begin(), all_inputs.end()}, + {persistent_buffer_info.projectable_persistent_buffers.begin(), + persistent_buffer_info.projectable_persistent_buffers.end()}); + for (auto expr : all_exprs) { + if (expr->isA() && + expr->as()->getUnaryOpType() == UnaryOpType::Exp) { + persistent_buffer_info.projection_with_exp_op = true; + } + + if (expr->isA()) { + persistent_buffer_info.projection_with_rng_op = true; + } + } + return persistent_buffer_info; } diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 6323a2a059b..50583d5241f 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -195,6 +195,12 @@ struct PersistentBufferInfo { // Map unmappable dims to projectable_buffer_inputs std::unordered_set unamppable_dims_projected_to_inputs; + + // Some parameters used in + // normalization_scheduler_utils::isProjectBufferToInput + bool has_view_ops = false; + bool projection_with_exp_op = false; + bool projection_with_rng_op = false; }; // Buffers whos roots can't map to all producer roots based on compute at. These From 1484ce51729d313085247a16dedd18450fc7e0f2 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 3 Sep 2024 08:49:05 -0700 Subject: [PATCH 42/54] Fix MLP's reference implementation. (#2891) --- tests/cpp/test_multidevice_transformer.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 6ac3a810118..b684df18cc5 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -119,9 +119,10 @@ std::vector reference_mlp( at::Tensor w1, at::Tensor b1, at::ScalarType at_dtype) { - auto linear0 = at::matmul(x, w0).add(b0).to(at::kFloat); + auto linear0 = at::matmul(x, w0).to(at::kFloat) + b0.to(at::kFloat); auto gelu = at::gelu(linear0, "tanh"); - auto linear1 = at::matmul(gelu.to(at_dtype), w1).add(b1).to(at::kFloat); + auto linear1 = + at::matmul(gelu.to(at_dtype), w1).to(at::kFloat) + b1.to(at::kFloat); auto dropout = at::dropout(linear1, kDropoutProb, true); return {linear0, gelu, linear1, dropout}; } @@ -202,11 +203,10 @@ std::vector mlp( TensorView* b0_bcast = broadcast(b0, {false, true, false}); TensorView* linear1 = add(matmul1, b0_bcast); // GeLU - TensorView* linear1_ = castOp(DataType::Float, linear1); - TensorView* gelu = tanh_gelu(linear1_); - TensorView* gelu_ = castOp(dtype, gelu); + TensorView* gelu = tanh_gelu(linear1); + gelu = castOp(dtype, gelu); // Linear #2 - TensorView* local_matmul2 = matmul(gelu_, w1); + TensorView* local_matmul2 = matmul(gelu, w1); TensorView* matmul2 = sum(local_matmul2, {0}); // Allreduce TensorView* bcast_bias = broadcast(b1, {true, false}); TensorView* linear2 = add(matmul2, bcast_bias); @@ -371,7 +371,7 @@ std::vector mlp_backwards( } // namespace TEST_P(DistributedTransformerTest, MLP_Layer) { - auto dtype = GetParam(); + DataType dtype = GetParam(); at::ScalarType at_dtype = data_type_to_aten(dtype); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -657,5 +657,7 @@ TEST_P(DistributedTransformerTest, Forward) { INSTANTIATE_TEST_SUITE_P( , DistributedTransformerTest, - testing::Values(DataType::Half, DataType::BFloat16)); + testing::Values(DataType::Half, DataType::BFloat16), + testing::PrintToStringParamName()); + } // namespace nvfuser From 04b241661c4b1b69afe5bba85af21625e6b2d842 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 3 Sep 2024 08:49:32 -0700 Subject: [PATCH 43/54] Move PipelineTest to test_multidevice_pipeline.cpp. (#2885) --- tests/cpp/multidevice.cpp | 93 ------------------- tests/cpp/multidevice.h | 21 ----- tests/cpp/test_multidevice_pipeline.cpp | 114 ++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 114 deletions(-) diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp index c39ecff53a6..679fd54f6b8 100644 --- a/tests/cpp/multidevice.cpp +++ b/tests/cpp/multidevice.cpp @@ -150,99 +150,6 @@ at::Tensor MultiDeviceTest::shardTensor( return slice; } -void PipelineTest::validate(bool validate_with_prescribed_values) { - if (!validate_with_prescribed_values) { - // execute the fusion on one device without pipeline scheduling - auto fusion_copy = std::make_unique(*runtime->completeFusion()); - unshard(fusion_copy.get()); - FusionExecutorCache unsharded_fec(std::move(fusion_copy)); - ref_unsharded_outputs = unsharded_fec.runFusionWithInputs(unsharded_inputs); - } - - if (debug_print) { - std::stringstream ss; - std::string indent = " "; - ss << "Device " << communicator_->deviceId() - << "'s expected (unsharded) outputs:{\n"; - for (auto& t : ref_unsharded_outputs) { - ss << indent << t; - } - ss << "\n}"; - std::cout << ss.str() << std::endl; - } - - ASSERT_EQ(ref_unsharded_outputs.size(), outputs.size()); - for (int i : c10::irange(runtime->completeFusion()->outputs().size())) { - ASSERT_TRUE(runtime->completeFusion()->outputs().at(i)->isA()); - auto output_tv = - runtime->completeFusion()->outputs().at(i)->as(); - if (!output_tv->getDeviceMesh().has(communicator_->deviceId())) { - continue; - } - auto ref_output = shardTensor(ref_unsharded_outputs.at(i), output_tv); - auto obtained_output = outputs.at(i); - EXPECT_TRUE(torch::allclose(ref_output, obtained_output)) - << "Device " << communicator_->deviceId() << " has unexpected output " - << i << " corresponding to tv " << output_tv - << ". Expected values: " << ref_output - << ", obtained values: " << obtained_output; - } -} - -// Run and validate a pipeline -// with given (possibly sharded) inputs -void PipelineTest::executeAndValidate(bool validate_with_prescribed_values) { - ASSERT_EQ(unsharded_inputs.size(), fusion->inputs().size()); - for (int i : c10::irange(fusion->inputs().size())) { - ASSERT_TRUE(fusion->inputs().at(i)->isA()); - auto input_tv = fusion->inputs().at(i)->as(); - auto input = shardTensor(unsharded_inputs.at(i).toTensor(), input_tv); - inputs.push_back(input); - } - - if (debug_print) { - if (!communicator_->deviceId()) { - fusion->printKernel(); - } - std::stringstream ss; - std::string indent = " "; - ss << "Device " << communicator_->deviceId() << "'s inputs:{\n"; - for (auto& t : inputs) { - ss << indent << t; - } - ss << "\n}"; - std::cout << ss.str() << std::endl; - } - - runtime = std::make_unique( - std::move(fusion), *communicator_, host_ir_executor_params); - auto error_msg = runtime->validate(); - if (error_msg != "") { - GTEST_SKIP() << error_msg; - } - outputs = runtime->runWithInput(inputs); - - if (debug_print) { - if (!communicator_->deviceId()) { - runtime->print(); - } - std::stringstream ss; - std::string indent = " "; - ss << "Device " << communicator_->deviceId() << "'s outputs:{\n"; - for (auto& t : outputs) { - ss << indent << t; - } - ss << "\n}"; - std::cout << ss.str() << std::endl; - } - validate(validate_with_prescribed_values); -} - -PipelineTest::PipelineTest() { - fusion = std::make_unique(); - communicator_->setDefaultBackend(CommunicatorBackend::nccl); -} - } // namespace nvfuser int main(int argc, char** argv) { diff --git a/tests/cpp/multidevice.h b/tests/cpp/multidevice.h index 3279e61739d..24d1c323215 100644 --- a/tests/cpp/multidevice.h +++ b/tests/cpp/multidevice.h @@ -47,25 +47,4 @@ class MultiDeviceTest : public NVFuserTest { void waitForDebuggerAtRank(DeviceIdxType rank); }; -class PipelineTest : public MultiDeviceTest { - protected: - PipelineTest(); - - // Utility function used for validation in the tests. It compares the - // (sharded) outputs with ref_unsharded_outputs. if - // validate_with_prescribed_values is true, ref_unsharded_outputs is assumed - // to be set manually in the test body. Otherwise, ref_unsharded_outputs is - // computed by running a Fusion on a single device with the unsharded_inputs - void validate(bool validate_with_prescribed_values = false); - void executeAndValidate(bool validate_with_prescribed_values = false); - - std::unique_ptr runtime; - std::unique_ptr fusion; - std::vector inputs; - std::vector unsharded_inputs; - std::vector outputs; - std::vector ref_unsharded_outputs; - hir::HostIrExecutorParams host_ir_executor_params; -}; - } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index cf2e7952d79..558c31b61e3 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -40,6 +40,120 @@ namespace nvfuser { +class PipelineTest : public MultiDeviceTest { + protected: + PipelineTest(); + + // Utility function used for validation in the tests. It compares the + // (sharded) outputs with ref_unsharded_outputs. if + // validate_with_prescribed_values is true, ref_unsharded_outputs is assumed + // to be set manually in the test body. Otherwise, ref_unsharded_outputs is + // computed by running a Fusion on a single device with the unsharded_inputs + void validate(bool validate_with_prescribed_values = false); + void executeAndValidate(bool validate_with_prescribed_values = false); + + std::unique_ptr runtime; + std::unique_ptr fusion; + std::vector inputs; + std::vector unsharded_inputs; + std::vector outputs; + std::vector ref_unsharded_outputs; + hir::HostIrExecutorParams host_ir_executor_params; +}; + +void PipelineTest::validate(bool validate_with_prescribed_values) { + if (!validate_with_prescribed_values) { + // execute the fusion on one device without pipeline scheduling + auto fusion_copy = std::make_unique(*runtime->completeFusion()); + unshard(fusion_copy.get()); + FusionExecutorCache unsharded_fec(std::move(fusion_copy)); + ref_unsharded_outputs = unsharded_fec.runFusionWithInputs(unsharded_inputs); + } + + if (debug_print) { + std::stringstream ss; + std::string indent = " "; + ss << "Device " << communicator_->deviceId() + << "'s expected (unsharded) outputs:{\n"; + for (auto& t : ref_unsharded_outputs) { + ss << indent << t; + } + ss << "\n}"; + std::cout << ss.str() << std::endl; + } + + ASSERT_EQ(ref_unsharded_outputs.size(), outputs.size()); + for (int i : c10::irange(runtime->completeFusion()->outputs().size())) { + ASSERT_TRUE(runtime->completeFusion()->outputs().at(i)->isA()); + auto output_tv = + runtime->completeFusion()->outputs().at(i)->as(); + if (!output_tv->getDeviceMesh().has(communicator_->deviceId())) { + continue; + } + auto ref_output = shardTensor(ref_unsharded_outputs.at(i), output_tv); + auto obtained_output = outputs.at(i); + EXPECT_TRUE(torch::allclose(ref_output, obtained_output)) + << "Device " << communicator_->deviceId() << " has unexpected output " + << i << " corresponding to tv " << output_tv + << ". Expected values: " << ref_output + << ", obtained values: " << obtained_output; + } +} + +// Run and validate a pipeline +// with given (possibly sharded) inputs +void PipelineTest::executeAndValidate(bool validate_with_prescribed_values) { + ASSERT_EQ(unsharded_inputs.size(), fusion->inputs().size()); + for (int i : c10::irange(fusion->inputs().size())) { + ASSERT_TRUE(fusion->inputs().at(i)->isA()); + auto input_tv = fusion->inputs().at(i)->as(); + auto input = shardTensor(unsharded_inputs.at(i).toTensor(), input_tv); + inputs.push_back(input); + } + + if (debug_print) { + if (!communicator_->deviceId()) { + fusion->printKernel(); + } + std::stringstream ss; + std::string indent = " "; + ss << "Device " << communicator_->deviceId() << "'s inputs:{\n"; + for (auto& t : inputs) { + ss << indent << t; + } + ss << "\n}"; + std::cout << ss.str() << std::endl; + } + + runtime = std::make_unique( + std::move(fusion), *communicator_, host_ir_executor_params); + auto error_msg = runtime->validate(); + if (error_msg != "") { + GTEST_SKIP() << error_msg; + } + outputs = runtime->runWithInput(inputs); + + if (debug_print) { + if (!communicator_->deviceId()) { + runtime->print(); + } + std::stringstream ss; + std::string indent = " "; + ss << "Device " << communicator_->deviceId() << "'s outputs:{\n"; + for (auto& t : outputs) { + ss << indent << t; + } + ss << "\n}"; + std::cout << ss.str() << std::endl; + } + validate(validate_with_prescribed_values); +} + +PipelineTest::PipelineTest() { + fusion = std::make_unique(); + communicator_->setDefaultBackend(CommunicatorBackend::nccl); +} + // To run the following tests on several devices, pytorch must be installed with // the flag USE_DISTRIBUTED=1 and nccl support. With that, nvFuser is built by // default with NVFUSER_DISTRIBUTED defined. Then, on a node with at least 6 From cd4807e8a8ec9339d8f29661a4ad753fa9e19218 Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Tue, 3 Sep 2024 08:56:52 -0700 Subject: [PATCH 44/54] Sharding propagation utility functions (#2836) Add a utility function `shardBetween` that shards all tensorviews between reference tensorviews and boundary tensorview like the reference(s), so that we can control shardingPropagation boundaries. Updates Transformer tests to disable the current sharding propagation pass and manually control sharding propagation boundaries with `shardBetween`. This let's us use the minimal number of manual sharding annotations (inputs, outputs, and reshardings). --- csrc/multidevice/utils.cpp | 48 +++++++++++++++ csrc/multidevice/utils.h | 20 ++++++ tests/cpp/test_multidevice_transformer.cpp | 71 ++++++++++------------ 3 files changed, 99 insertions(+), 40 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 9b32ba9c690..18189224994 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -231,6 +231,54 @@ void shardAllLike(TensorView* ref, std::vector tvs) { } } +void shardBetween( + const std::vector& from, + const std::vector& to, + TensorView* ref) { + std::vector from_tvs; + std::vector to_tvs; + for (auto expr : from) { + auto outputs = ir_utils::filterByType(expr->outputs()); + std::copy(outputs.begin(), outputs.end(), std::back_inserter(from_tvs)); + } + + for (auto expr : to) { + auto outputs = ir_utils::filterByType(expr->outputs()); + std::copy(outputs.begin(), outputs.end(), std::back_inserter(to_tvs)); + } + + shardBetween(from_tvs, to_tvs, ref); +} + +void shardBetween( + const std::vector& from, + const std::vector& to, + TensorView* ref) { + std::unordered_set boundary = {to.begin(), to.end()}; + for (auto tv : from) { + auto expr = tv->definition(); + if (expr == nullptr) { + continue; + } + auto inputs = ir_utils::filterByType(expr->inputs()); + std::copy( + inputs.begin(), inputs.end(), std::inserter(boundary, boundary.end())); + } + + std::unordered_set all_tvs = + scheduler_utils::getAllTvsFrom(from, boundary); + shardAllLike(ref, {all_tvs.begin(), all_tvs.end()}); + + // Remove DID parallelizations on reduction axes. + for (auto* tv : all_tvs) { + for (IterDomain* id : tv->getLoopDomain()) { + if (id->isReduction() && id->isDeviceDim()) { + id->parallelize(ParallelType::Serial); + } + } + } +} + int64_t requestedNumberOfDevices(Fusion* fusion) { DeviceIdxType max_index = 0; for (auto tv : fusion->allTvs()) { diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 4b2bba3279b..8dcb19c4548 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -86,8 +86,28 @@ bool haveDifferentShardings( // Returns whether a resharding expr reshards an inner axis bool isInnerResharding(Expr* expr); +// Shards all tensors in tvs like reference void shardAllLike(TensorView* ref, std::vector tvs); +// Shards all TVs between from and to AND between TVs created inside a fusion +// and to. This is required for (1) expressions like rng_uniform that create a +// TV inside a fusion that is not between a path from user visible TVs. (2) +// multi-output expressions may have output tensors that are not along a path to +// the fusion output which would not be reachable otherwise. (2) sharding +// propagation checks all TVs in the fusion are assigned a device mesh +// regardless if they are reachable. To keep the checks simple, we require all +// TVs are assigned a mesh if they exist in the fusion. +void shardBetween( + const std::vector& from, + const std::vector& to, + TensorView* ref); +// Same as above but using the outputs of the from and to expressions +// to form the from and to TVs. +void shardBetween( + const std::vector& from, + const std::vector& to, + TensorView* ref); + // Returns the devices involved in an expr std::set involvedDevices(Expr* expr); diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index b684df18cc5..84bfc6e5914 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -61,21 +61,6 @@ class DistributedTransformerTest }; namespace { -TensorView* replicated_dropout( - TensorView* x, - const double kProb, - const DeviceMesh mesh) { - // Sharding propagation breaks at rand_like because it creates a fresh TV. - TensorView* x_float = castOp(DataType::Float, x); - const double kScale = 1.0 / (1.0 - kProb); - TensorView* rand_vals = rand_like(x_float); - TensorView* mask = lt(rand_vals, IrBuilder::create(1.0 - kProb)); - TensorView* apply_mask = mul(x_float, mask); - TensorView* dropout = mul(apply_mask, IrBuilder::create(kScale)); - rand_vals->setDeviceMesh(mesh); - return dropout; -} - void validate( std::vector expected_out, std::vector out) { @@ -211,14 +196,12 @@ std::vector mlp( TensorView* bcast_bias = broadcast(b1, {true, false}); TensorView* linear2 = add(matmul2, bcast_bias); // Dropout - TensorView* dropout = replicated_dropout(linear2, kDropoutProb, mesh); - - // Sharding - // (TODO) TVs where sharding propagation breaks down: - // linear_int0: broadcasts where a device dim axis is broadcasted. - // rand_vals: rand_like creates a fresh new TV. - // TVs replicated on each device. - for (auto tv : {x, b1, matmul2, linear2, dropout}) { + Val* prob = IrBuilder::create(1.0 - kDropoutProb); + Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); + auto dropout_result = dropout(linear2, prob, scale).output; + + // Manual sharding annotations + for (auto tv : {x, b1, matmul2, linear2, dropout_result}) { tv->setDeviceMesh(mesh); } for (auto tv : {w0, b0, w1, linear1, gelu}) { @@ -226,7 +209,7 @@ std::vector mlp( tv->axis(0)->parallelize(ParallelType::DIDx); } - return {linear1, gelu, linear2, dropout}; + return {linear1, gelu, linear2, dropout_result}; } std::vector mha( @@ -277,16 +260,18 @@ std::vector mha( TensorView* b1_bcast = broadcast(b1, {true, false}); TensorView* linear2 = add(mm2_ar, b1_bcast); // Dropout - TensorView* dropout = replicated_dropout(linear2, kDropoutProb, mesh); + Val* prob = IrBuilder::create(1.0 - kDropoutProb); + Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); + auto dropout_result = dropout(linear2, prob, scale).output; - for (auto tv : {x, b1, mm2_ar, linear2, dropout}) { + for (auto tv : {x, b1, mm2_ar, linear2, dropout_result}) { tv->setDeviceMesh(mesh); } - for (auto tv : {w0, b0, w1, proj_bias_bcast, mm, mm2, qkv, sdpa_output}) { + for (auto tv : {w0, b0, w1, mm2, qkv, sdpa_output}) { tv->setDeviceMesh(mesh); tv->axis(0)->parallelize(ParallelType::DIDx); } - return {qkv, sdpa_output, linear2, dropout}; + return {qkv, sdpa_output, linear2, dropout_result}; } std::vector mlp_backwards( @@ -331,14 +316,9 @@ std::vector mlp_backwards( TensorView* matmul0_grad_w = transpose(matmul0_grad_w_t, 1, 2); TensorView* matmul0_grad_b = sum(gelu_grad, {1}); + // Manaul sharding annotations for (auto tv : - {x, - grad, - mask, - dropout_grad, - matmul1_grad_x, - matmul1_grad_b, - matmul0_grad_x}) { + {x, grad, mask, dropout_grad, matmul1_grad_b, matmul0_grad_x}) { tv->setDeviceMesh(mesh); } @@ -346,18 +326,14 @@ std::vector mlp_backwards( {w0, b0, w1, - matmul0, - matmul1_grad_x, matmul1_grad_w, - matmul1_grad_w_t, gelu_grad, - matmul0_grad_w_t, matmul0_grad_w, - matmul0_grad_x_partial, matmul0_grad_b}) { tv->setDeviceMesh(mesh); tv->axis(0)->parallelize(ParallelType::DIDx); } + std::vector outputs = { dropout_grad, matmul1_grad_w, @@ -395,6 +371,8 @@ TEST_P(DistributedTransformerTest, MLP_Layer) { for (TensorView* tv : tvsout) { fusion->addOutput(tv); } + shardBetween({tvw0, tvb0, tvw1}, {tvsout[3]}, tvw0); + shardBetween({tvx, tvb1}, {tvsout[3]}, tvx); const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); @@ -455,6 +433,9 @@ TEST_P(DistributedTransformerTest, Multiheaded_Attention) { fusion->addOutput(tv); } + shardBetween({tvw0, tvb0, tvw1}, {tv_outs[3]}, tvw0); + shardBetween({tvx, tvb1}, {tv_outs[3]}, tvx); + const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); auto x = at::randn({B * S, E}, options); @@ -513,6 +494,12 @@ TEST_P(DistributedTransformerTest, MLP_Backward) { fusion->addOutput(tv); } + // Sharded: matmul1_grad_w, gelu_grad, matmul0_grad_w, matmul0_grad_b + shardBetween( + {w0, b0, w1}, {tv_outs[1], tv_outs[3], tv_outs[4], tv_outs[5]}, w0); + // Unsharded: dropout_grad, matmul1_grad_b, matmul0_grad_x + shardBetween({grad, x}, {tv_outs[0], tv_outs[2], tv_outs[6]}, grad); + const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); auto grad_ = at::randn({B * S, E}, options).to(at::kFloat); @@ -600,6 +587,10 @@ TEST_P(DistributedTransformerTest, Forward) { tv->setDeviceMesh(mesh); } + shardBetween({mha_in->definition()}, {mha_out->definition()}, mha_w0); + shardBetween({mlp_in->definition()}, {mlp_out->definition()}, mlp_w0); + shardBetween({x}, {mha_in}, x); + const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); auto x_ = at::randn({B * S, E}, options).to(at::kFloat); From 305a45857f164cc9a3650d193c35d1e605cdcc99 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 3 Sep 2024 13:46:54 -0400 Subject: [PATCH 45/54] Cleanup NVTX Markers for dynamic shape profiling (#2886) Cleanup nvtx markers for dynamic latencies: remove those that weren't significant, add others that were, make schedulers consistent. --- csrc/executor.cpp | 13 +++------- csrc/kernel_cache.cpp | 27 ++++++++++++-------- csrc/scheduler/normalization_inner.cpp | 8 +++--- csrc/scheduler/normalization_inner_outer.cpp | 14 +++++----- csrc/scheduler/normalization_outer.cpp | 8 +++--- csrc/scheduler/normalization_utils.cpp | 3 ++- csrc/scheduler/pointwise.cpp | 8 +++--- csrc/scheduler/reduction.cpp | 10 +++----- csrc/scheduler/registry.cpp | 4 +++ csrc/scheduler/transpose.cpp | 7 +++-- csrc/scheduler/vectorize_helper.cpp | 6 +++++ csrc/transform_view.cpp | 1 - 12 files changed, 58 insertions(+), 51 deletions(-) diff --git a/csrc/executor.cpp b/csrc/executor.cpp index ec37aa747ff..f9f8bf7ba5c 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -563,8 +563,6 @@ std::pair, std::vector> inferShape( std::vector symbolic_sizes, std::vector expand_flags, ExpressionEvaluator& expr_eval) { - FUSER_PERF_SCOPE("inferShape"); - // Allocate should be provided for intermediates. We just need to // grab a chunk of memory of the size dicatated by // Allocate::shape(). Fusion outputs do not come with Allocate and @@ -1039,7 +1037,7 @@ std::vector allocateOutputs( const std::vector& output_info, const c10::Device& device, ExpressionEvaluator& ee) { - FUSER_PERF_SCOPE("allocateOutputs"); + FUSER_PERF_SCOPE("executor.cpp::allocateOutputs"); const auto num_outs = output_info.size(); @@ -1092,7 +1090,6 @@ int64_t FusionExecutor::computeSharedMemory( const std::vector& buffers, DataType index_type, int64_t smem_offset) { - FUSER_PERF_SCOPE("FusionExecutor::computeSharedMemory"); int64_t total = smem_offset; // align smem_offset at 16 bytes smem_offset = (smem_offset + 15) & (~15); @@ -1303,8 +1300,6 @@ std::vector FusionExecutor:: getIntermediateBufferInfo( ExpressionEvaluator& expr_eval, DataType index_type) { - FUSER_PERF_SCOPE("FusionExecutor::getIntermediateBufferInfo"); - std::vector global_buffers; const auto kernel = lowered_->kernel(); @@ -1361,7 +1356,6 @@ std::vector getOutputBufferInfo( ExpressionEvaluator& expr_eval, DataType index_dtype, const Fusion* fusion) { - FUSER_PERF_SCOPE("FusionExecutor::getOutbufferInfo"); std::vector outputs; outputs.reserve(fusion->outputs().size()); NVF_ERROR( @@ -1695,7 +1689,7 @@ void FusionExecutor::computeArgs( ExecutorEntry& entry, ExpressionEvaluator& expr_eval, const kir::Kernel* kernel) const { - FUSER_PERF_SCOPE("Initial GetArgsBuffers"); + FUSER_PERF_SCOPE("FusionExecutor::computeArgs"); const std::vector& params = kernel->parameters(); entry.args.resize(params.size()); @@ -1713,7 +1707,7 @@ void FusionExecutor::recomputeArgs( ExecutorEntry& entry, ExpressionEvaluator& expr_eval, const kir::Kernel* kernel) const { - FUSER_PERF_SCOPE("Recompute GetArgsBuffers"); + FUSER_PERF_SCOPE("FusionExecutor::recomputeArgs"); // assert(entry.init && "entry was never initialized"); const std::vector& params = kernel->parameters(); @@ -2021,7 +2015,6 @@ std::vector FusionExecutor::runFusion( std::vector intermediates; at::Tensor profile_buffer; { - FUSER_PERF_SCOPE("ExecutorRunFusion::IntermediateBufferAlloc"); for (const auto i : c10::irange(executor_entry->intermediates.size())) { const auto& buf_info = executor_entry->intermediates.at(i); bool has_expansion = false; diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 90d498103e2..65e886036c9 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -694,6 +694,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( const KernelArgumentHolder& args, std::optional forced_index_type) { // Check for id hit case (Path 1) + FUSER_PERF_SCOPE("FusionExecutorCache::getKernelRuntimeFor"); auto unique_id_opt = args.getCacheId(); NVF_CHECK( unique_id_opt.has_value(), @@ -773,6 +774,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( } if (!reusing) { + FUSER_PERF_SCOPE("FusionExecutorCache::getKernelRuntimeFor::!reusing"); // Paths 3 or 4 // cache miss, need to re-build an optimized graph for this case @@ -1318,7 +1320,6 @@ void FusionKernelRuntime::compileKernel( std::pair FusionKernelRuntime::getKernelConfig( const KernelArgumentHolder& args, SegmentedGroup* sg) { - FUSER_PERF_SCOPE("FusionKernelRuntime::getKernelConfig"); auto group_id = sg->groupId(); auto scheduler_entry = schedulers().at(group_id).get(); @@ -1363,6 +1364,7 @@ std::vector FusionKernelRuntime::runWithInputs( std::unordered_map FusionKernelRuntime:: runSegmentsWithInputs(KernelArgumentHolder& args) { + FUSER_PERF_SCOPE("FusionKernelRuntime::runSegmentsWithInputs"); NVF_ERROR( args.size() == segmented_fusion_->inputs().size(), "Inputs were not set up correctly, received ", @@ -1433,7 +1435,6 @@ const std::vector& FusionKernelRuntime:: void FusionKernelRuntime::updateHeuristicsLaunchParams( FusionHeuristics* update_heuristics) { - FUSER_PERF_SCOPE("FusionKernelRuntime::updateHeuristicsLaunchParams"); auto scheduler_list_length = heuristics_->heuristicsList().size(); NVF_ERROR( update_heuristics->heuristicsList().size() == scheduler_list_length); @@ -1461,7 +1462,6 @@ std::optional FusionKernelRuntime:: KernelArgumentHolder mutable_args(args); ArgumentManager args_manager( mutable_args, runtime_workspace_, segmented_fusion_->inputs()); - // Follow group run order for (int64_t group_id : c10::irange(num_groups)) { auto group_to_run = runtime_workspace_.group_run_order.at(group_id); @@ -1478,14 +1478,19 @@ std::optional FusionKernelRuntime:: } // Create PrecomputedValues for fusion segment - auto evaluator_precomputed_values = - std::make_unique(fusion_to_run); - evaluator_precomputed_values->bindInputs(group_runtime_inputs); - // TODO Remove binding the original fusion inputs when creating heuristics - // for fusion segment. - evaluator_precomputed_values->bindValues( - group_to_run->getCompleteFusionInputs(), args); - evaluator_precomputed_values->evaluate(); + std::unique_ptr evaluator_precomputed_values; + { + FUSER_PERF_SCOPE( + "FusionKernelRuntime::getMaybeHeuristicsFor::PrecomputedValues"); + evaluator_precomputed_values = + std::make_unique(fusion_to_run); + evaluator_precomputed_values->bindInputs(group_runtime_inputs); + // TODO Remove binding the original fusion inputs when creating heuristics + // for fusion segment. + evaluator_precomputed_values->bindValues( + group_to_run->getCompleteFusionInputs(), args); + evaluator_precomputed_values->evaluate(); + } // Get all tensorviews for segmented fusion std::vector all_tvs_for_fusion_to_run = diff --git a/csrc/scheduler/normalization_inner.cpp b/csrc/scheduler/normalization_inner.cpp index ea5fea09236..c168bee2eed 100644 --- a/csrc/scheduler/normalization_inner.cpp +++ b/csrc/scheduler/normalization_inner.cpp @@ -27,11 +27,12 @@ InnerPersistentKernelScheduler::InnerPersistentKernelScheduler( } void InnerPersistentKernelScheduler::schedule(Fusion* fusion) { - FUSER_PERF_SCOPE("Schedule InnerPersistent Fusion"); + FUSER_PERF_SCOPE("InnerPersistentKernelScheduler::schedule"); scheduleInnerPersistentKernel(fusion, reductionParams()); } bool InnerPersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) { + FUSER_PERF_SCOPE("InnerPersistentKernelScheduler::canScheduleCompileTime"); return normalization_scheduler_utils::compileTimeCheck( fusion, heuristicType()); } @@ -83,7 +84,7 @@ bool InnerPersistentKernelScheduler::canScheduleRunTime( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("InnerPersistentKernelScheduler::canSchedule"); + FUSER_PERF_SCOPE("InnerPersistentKernelScheduler::canScheduleRunTime"); auto reduction_tv_entry = HeuristicSummaryEntry( data_cache, [&fusion]() { @@ -162,6 +163,7 @@ void InnerPersistentKernelScheduler::computeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("InnerPersistentKernelScheduler::computeHeuristics"); params_ = getInnerPersistentHeuristics(fusion, runtime_info, data_cache); NVF_ERROR(params_ != nullptr); } @@ -1086,7 +1088,6 @@ std::shared_ptr getInnerPersistentHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getInnerPersistentHeuristics"); FusionGuard fg(fusion); // properties of the fusion @@ -1133,7 +1134,6 @@ std::shared_ptr getInnerPersistentHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getInnerPersistentHeuristicsFromIValue"); SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs); return getInnerPersistentHeuristics(fusion, runtime_info, data_cache); } diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 852426cd908..77b8af77927 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -27,12 +27,14 @@ InnerOuterPersistentKernelScheduler::InnerOuterPersistentKernelScheduler( } void InnerOuterPersistentKernelScheduler::schedule(Fusion* fusion) { - FUSER_PERF_SCOPE("Schedule InnerOuterPersistent Fusion"); + FUSER_PERF_SCOPE("InnerOuterPersistentKernelScheduler::schedule"); scheduleInnerOuterPersistentKernel(fusion, reductionParams()); } bool InnerOuterPersistentKernelScheduler::canScheduleCompileTime( Fusion* fusion) { + FUSER_PERF_SCOPE( + "InnerOuterPersistentKernelScheduler::canScheduleCompileTime"); // common checks for all persistent heuristics if (!normalization_scheduler_utils::checkOpsAndInputs( fusion, heuristicType())) { @@ -353,6 +355,9 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( HeuristicSummary* data_cache, const std::vector& reduction_tvs, const int64_t vectorize_factor) { + FUSER_PERF_SCOPE( + "normalization_inner_outer::getPersistentBufferStorageParams"); + PersistentBufferStorageParams buffer_params; auto persistent_buffer_info_entry = @@ -573,7 +578,7 @@ bool InnerOuterPersistentKernelScheduler::canScheduleRunTime( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("InnerOuterPersistentKernelScheduler::canSchedule"); + FUSER_PERF_SCOPE("InnerOuterPersistentKernelScheduler::canScheduleRunTime"); auto reduction_tv_entry = HeuristicSummaryEntry( data_cache, [&fusion]() { @@ -656,6 +661,7 @@ void InnerOuterPersistentKernelScheduler::computeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("InnerOuterPersistentKernelScheduler::computeHeuristics"); params_ = getInnerOuterPersistentHeuristics(fusion, runtime_info, data_cache); NVF_ERROR(params_ != nullptr); } @@ -931,7 +937,6 @@ std::shared_ptr getInnerOuterPersistentHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getInnerOuterPersistentHeuristics"); FusionGuard fg(fusion); auto reduction_tv_entry = @@ -1017,7 +1022,6 @@ std::shared_ptr getInnerOuterPersistentHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getInnerOuterPersistentHeuristicsFromIValue"); SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs); return getInnerOuterPersistentHeuristics(fusion, runtime_info, data_cache); } @@ -1141,8 +1145,6 @@ void scheduleReductionCombinedOuter( void scheduleInnerOuterPersistentKernel( Fusion* fusion, const ReductionParams& rparams) { - FUSER_PERF_SCOPE("scheduleInnerOuterPersistentKernel"); - FusionGuard fg(fusion); // Grab the reduction, input, and output tensor views. dummy_outputs are diff --git a/csrc/scheduler/normalization_outer.cpp b/csrc/scheduler/normalization_outer.cpp index 4c2fe5ae8bb..1e2f15efa76 100644 --- a/csrc/scheduler/normalization_outer.cpp +++ b/csrc/scheduler/normalization_outer.cpp @@ -27,11 +27,12 @@ OuterPersistentKernelScheduler::OuterPersistentKernelScheduler( } void OuterPersistentKernelScheduler::schedule(Fusion* fusion) { - FUSER_PERF_SCOPE("Schedule OuterPersistent Fusion"); + FUSER_PERF_SCOPE("OuterPersistentKernelScheduler::schedule"); scheduleOuterPersistentKernel(fusion, reductionParams()); } bool OuterPersistentKernelScheduler::canScheduleCompileTime(Fusion* fusion) { + FUSER_PERF_SCOPE("OuterPersistentKernelScheduler::canScheduleCompileTime"); return normalization_scheduler_utils::compileTimeCheck( fusion, heuristicType()); } @@ -40,7 +41,7 @@ bool OuterPersistentKernelScheduler::canScheduleRunTime( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("OuterPersistentKernelScheduler::canSchedule"); + FUSER_PERF_SCOPE("OuterPersistentKernelScheduler::canScheduleRunTime"); auto reduction_tv_entry = HeuristicSummaryEntry( data_cache, [&fusion]() { @@ -242,6 +243,7 @@ void OuterPersistentKernelScheduler::computeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("OuterPersistentKernelScheduler::computeHeuristics"); params_ = getOuterPersistentHeuristics(fusion, runtime_info, data_cache); NVF_ERROR(params_ != nullptr); } @@ -637,7 +639,6 @@ std::shared_ptr getOuterPersistentHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getOuterPersistentHeuristics"); FusionGuard fg(fusion); const auto& prop = @@ -663,7 +664,6 @@ std::shared_ptr getOuterPersistentHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getOuterPersistentHeuristicsFromIValue"); SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs); return getOuterPersistentHeuristics(fusion, runtime_info, data_cache); } diff --git a/csrc/scheduler/normalization_utils.cpp b/csrc/scheduler/normalization_utils.cpp index a8a50ebdee9..c7a9542d614 100644 --- a/csrc/scheduler/normalization_utils.cpp +++ b/csrc/scheduler/normalization_utils.cpp @@ -845,7 +845,8 @@ PersistentKernelProperties getPersistentKernelProperties( SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache, ScheduleHeuristic heuristic) { - FUSER_PERF_SCOPE("getPersistentKernelProperties"); + FUSER_PERF_SCOPE( + "normalization_scheduler_utils::getPersistentKernelProperties"); auto reduction_tv_entry = HeuristicSummaryEntry( diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 2445ab79afa..9ffd9acd6d2 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -33,6 +33,7 @@ PointWiseScheduler::PointWiseScheduler( bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) { if (scheduler_utils::isResharding(fusion)) { + FUSER_PERF_SCOPE("PointWiseScheduler::canScheduleCompileTime"); scheduler_debug_utils::canScheduleRejectReason( heuristicType(), "Fusion is resharding."); return false; @@ -82,6 +83,7 @@ bool PointWiseScheduler::canScheduleRunTime( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("PointWiseScheduler::canScheduleRunTime"); auto can_schedule_transpose_entry = HeuristicSummaryEntry( data_cache, [fusion]() { @@ -98,7 +100,7 @@ bool PointWiseScheduler::canScheduleRunTime( } void PointWiseScheduler::schedule(Fusion* fusion) { - FUSER_PERF_SCOPE("Schedule PointWise Fusion"); + FUSER_PERF_SCOPE("PointWiseScheduler::schedule"); schedulePointwise(fusion, pointwiseParams()); } @@ -106,6 +108,7 @@ void PointWiseScheduler::computeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("PointWiseScheduler::computeHeuristics"); params_ = getPointwiseHeuristics(fusion, runtime_info, data_cache); NVF_ERROR(params_ != nullptr); } @@ -160,8 +163,6 @@ std::shared_ptr getPointwiseHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getPointwiseHeuristics"); - FusionGuard fg(fusion); // Incase any buffer is of type DataType::Index @@ -512,7 +513,6 @@ std::shared_ptr getPointwiseHeuristics( LaunchParams schedulePointwise( Fusion* fusion, const at::ArrayRef& runtime_inputs) { - FUSER_PERF_SCOPE("scheduleFusion"); auto params = getPointwiseHeuristics(fusion, runtime_inputs); NVF_ERROR(params != nullptr, "Could not schedule pointwise operation."); schedulePointwise(fusion, *params); diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index f6feb34ba52..3c293d25c2d 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -1056,17 +1056,19 @@ void ReductionScheduler::computeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("ReductionScheduler::computeHeuristics"); params_ = getReductionHeuristics(fusion, runtime_info, data_cache); NVF_ERROR(params_ != nullptr); } void ReductionScheduler::schedule(Fusion* fusion) { - FUSER_PERF_SCOPE("Schedule Single Reduction"); + FUSER_PERF_SCOPE("ReductionScheduler::schedule"); scheduleReduction(fusion, reductionParams()); } //! Check if the reduction heuristics apply in given fusion bool ReductionScheduler::canScheduleCompileTime(Fusion* fusion) { + FUSER_PERF_SCOPE("ReductionScheduler::canScheduleCompileTime"); if (scheduler_utils::isResharding(fusion)) { scheduler_debug_utils::canScheduleRejectReason( heuristicType(), "Fusion is resharding."); @@ -1210,6 +1212,7 @@ bool ReductionScheduler::canScheduleRunTime( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("ReductionScheduler::canScheduleRunTime"); return true; } @@ -1244,8 +1247,6 @@ std::shared_ptr getReductionHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getReductionHeuristics"); - SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs); return getReductionHeuristics(fusion, runtime_info, data_cache); @@ -1255,8 +1256,6 @@ std::shared_ptr getReductionHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getReductionHeuristics"); - FusionGuard fg(fusion); auto reduction_tv_entry = @@ -1352,7 +1351,6 @@ std::shared_ptr getReductionHeuristics( // fusion is the input IR that will be modified by this function void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) { - FUSER_PERF_SCOPE("scheduleReduction"); FusionGuard fg(fusion); bool unroll = rparams.isUnrolled(); diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 7bddf484bb0..6ba9346a845 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include #include #include @@ -24,6 +25,7 @@ SchedulerRuntimeInfo::SchedulerRuntimeInfo( const std::vector& all_tvs, std::optional forced_index_type) : complete_fusion_(complete_fusion) { + FUSER_PERF_SCOPE("SchedulerRuntimeInfo::SchedulerRuntimeInfo"); NVF_ERROR( complete_fusion_->inputs().size() == args.size(), "The provided fusion group expects ", @@ -174,6 +176,7 @@ bool checkCanSchedule( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache = nullptr) { + FUSER_PERF_SCOPE("SchedulerRuntimeInfo::checkCanSchedule"); // ExprEval scheduler only requires `canScheduleCompileTime` check and should // not use this fn. The following checks build the computeAt map that do not // work with SDPAOp. @@ -273,6 +276,7 @@ bool checkCanSchedule( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("SchedulerEntry::makeEntry"); std::unique_ptr scheduler_entry = nullptr; switch (sh) { case ScheduleHeuristic::NoOp: diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 7ed05ffbf2e..9271db9dcd4 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -28,6 +28,7 @@ TransposeScheduler::TransposeScheduler( } bool TransposeScheduler::canScheduleCompileTime(Fusion* fusion) { + FUSER_PERF_SCOPE("TransposeScheduler::canScheduleCompileTime"); if (scheduler_utils::isResharding(fusion)) { scheduler_debug_utils::canScheduleRejectReason( heuristicType(), "Fusion is resharding."); @@ -112,7 +113,7 @@ bool TransposeScheduler::canScheduleRunTime( } void TransposeScheduler::schedule(Fusion* fusion) { - FUSER_PERF_SCOPE("Schedule Transpose Fusion"); + FUSER_PERF_SCOPE("TransposeScheduler::schedule"); scheduleTranspose(fusion, transposeParams()); } @@ -120,6 +121,7 @@ void TransposeScheduler::computeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("TransposeScheduler::computeHeuristics"); params_ = getTransposeHeuristics(fusion, runtime_info, data_cache); NVF_ERROR(params_ != nullptr); } @@ -832,8 +834,6 @@ std::shared_ptr getTransposeHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getTransposeHeuristics"); - FusionGuard fg(fusion); // Incase any buffer is of type DataType::Index @@ -1049,7 +1049,6 @@ std::shared_ptr getTransposeHeuristics( LaunchParams scheduleTranspose( Fusion* fusion, const at::ArrayRef& runtime_inputs) { - FUSER_PERF_SCOPE("scheduleFusion"); auto params = getTransposeHeuristics(fusion, runtime_inputs); NVF_ERROR(params != nullptr, "Could not schedule transpose operation."); scheduleTranspose(fusion, *params); diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index 22103812976..4c650388c84 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -802,6 +803,8 @@ int64_t getVectorizationFactor( HeuristicSummary* data_cache, int64_t break_point, const std::unordered_map& logical_reorder_map) { + FUSER_PERF_SCOPE("vectorize_helper::getVectorizationFactor"); + auto vectorizable_inputs_outputs_entry = HeuristicSummaryEntry( data_cache, [&reference_tv]() { @@ -916,6 +919,9 @@ int64_t getVectorizationBreakPointOfReductionProducer( TensorView* reduction_consumer, TensorView* reduction_producer, int64_t consumer_innermost_ndims) { + FUSER_PERF_SCOPE( + "vectorize_helper::getVectorizationBreakPointOfReductionProducer"); + NVF_ERROR( reduction_consumer->definition() != nullptr && ir_utils::isReductionOp(reduction_consumer->definition()) && diff --git a/csrc/transform_view.cpp b/csrc/transform_view.cpp index d37a7b4281e..7039f336a59 100644 --- a/csrc/transform_view.cpp +++ b/csrc/transform_view.cpp @@ -759,7 +759,6 @@ AnalyzeViewResult analyzeView( const TensorView* original_view_tv, const std::vector& original_sizes, const std::vector& new_sizes) { - FUSER_PERF_SCOPE("analyzeView"); if (original_sizes.empty()) { NVF_ERROR( std::all_of( From f710a9217faac1dfe29a893d18fc7b261a17c561 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Tue, 3 Sep 2024 13:55:00 -0400 Subject: [PATCH 46/54] Update expected error message in cat error tests (#2892) This was broken in #2872 but this should fix CI. --- tests/python/opinfo_input_generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/opinfo_input_generators.py b/tests/python/opinfo_input_generators.py index 779f84bfb3b..fa5abe5322a 100644 --- a/tests/python/opinfo_input_generators.py +++ b/tests/python/opinfo_input_generators.py @@ -311,7 +311,7 @@ def cat_error_generator(op, dtype=torch.float32, requires_grad: bool = False, ** "Unexpected number of dimensions", ) # All tensors must have same shape except for the cat dimension - shape_mismatch = (([(2, 3), (4, 5)], 0), RuntimeError, "known_size == this_size") + shape_mismatch = (([(2, 3), (4, 5)], 0), RuntimeError, "Tried to bind to a value") error_cases = [ empty_input_tensors, From 0fea9e75d59dd1be2e8071479da6e645966d4980 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:05:23 -0400 Subject: [PATCH 47/54] Skip op checks in checkCanSchedule if we are given a data_cache (#2893) During scheduling, no data cache is provided when calling `checkCanSchedule`. However, when getting heuristics for a segment, specifically when we are checking whether we can re-use a compiled runtime for new input shapes, we do pass a `data_cache` so that we can re-use some objects for scheduling. In these cases, we have already passed the compile time checks so the check is redundant. In this PR, we avoid the scheduler-agnostic compile time checks added in #2526 when the `data_cache` is given. --- csrc/scheduler/registry.cpp | 40 ++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 6ba9346a845..1934235f41b 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -183,29 +183,27 @@ bool checkCanSchedule( NVF_ERROR(SchedulerType::heuristicType() != ScheduleHeuristic::ExprEval); FusionGuard fg(fusion); + // If a data cache is given, the compile time part doesn't need to be checked, + // since during segmentation the segmenter will call + // SchedulerEntry::proposeHeuristics which doesn't pass a data_cache. + if (data_cache == nullptr) { + // Fusions with `SdpaFwdOp/SdpaBwdOp` are only accepted in `ExprEval` + // scheduler, all other schedulers should reject them. + if (ir_utils::hasOpsOfType(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + SchedulerType::heuristicType(), "SdpaOps are not supported."); + return false; + } - // Fusions with `SdpaFwdOp/SdpaBwdOp` are only accepted in `ExprEval` - // scheduler, all other schedulers should reject them. - if (ir_utils::hasOpsOfType(fusion)) { - scheduler_debug_utils::canScheduleRejectReason( - SchedulerType::heuristicType(), "SdpaOps are not supported."); - return false; - } - - // Fusions with `MatmulOp, LinearOp, MmaOp` can only be accepted by Matmul - // scheduler. - if (SchedulerType::heuristicType() != ScheduleHeuristic::Matmul && - ir_utils::hasOpsOfType(fusion)) { - scheduler_debug_utils::canScheduleRejectReason( - SchedulerType::heuristicType(), "Matmul ops are not supported."); - return false; - } + // Fusions with `MatmulOp, LinearOp, MmaOp` can only be accepted by Matmul + // scheduler. + if (SchedulerType::heuristicType() != ScheduleHeuristic::Matmul && + ir_utils::hasOpsOfType(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + SchedulerType::heuristicType(), "Matmul ops are not supported."); + return false; + } - // If a data cache is given, the compile time part doesn't need to be checked, - // since for all current use cases - // it has to pass all the compile time checks to create a data cache for this - // fusion. - if (!data_cache) { if (!registry_utils::isConnectedFusionGraph(fusion)) { scheduler_debug_utils::canScheduleRejectReason( SchedulerType::heuristicType(), From 7d8a9212f76436ee9da6222fc7813a2b604856f7 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 3 Sep 2024 11:05:49 -0700 Subject: [PATCH 48/54] Clean test_multidevice_sharding.cpp. (#2883) --- tests/cpp/test_multidevice_sharding.cpp | 76 +++++++++++++++---------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index fbccf278911..9144636c059 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -16,35 +16,36 @@ namespace nvfuser { // params: concrete vs symbolic input, sharded axis -class MultideviceShardingTest +class MultiDeviceReductionTest : public MultiDeviceTest, public testing::WithParamInterface> {}; -// Test memory allocation of multidevice fusion with unsharded inputs -// and sharded intermediates, outputs. -TEST_P(MultideviceShardingTest, UnshardedGlobalInput) { - auto [creates_concrete_tensor, sharded_dim] = GetParam(); +// Test multidevice fusion with unsharded inputs and sharded intermediates, +// outputs. +TEST_P(MultiDeviceReductionTest, UnshardedInput_ShardedOutput) { + auto [creates_concrete_tensor, sharded_input_dim] = GetParam(); + auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - int num_devices = communicator_->size(); + const int num_devices = communicator_->size(); auto mesh = DeviceMesh::createForNumDevices(num_devices); - std::vector input_size = {2, 3, 2, num_devices}; - input_size[sharded_dim] = num_devices; + std::vector input_shape = {2, 3, 2, num_devices}; + input_shape[sharded_input_dim] = num_devices; TensorView* tv0 = creates_concrete_tensor - ? makeContigConcreteTensor(input_size) - : makeContigTensor(4); + ? makeContigConcreteTensor(input_shape) + : makeContigTensor(input_shape.size()); TensorView* tv1 = set(tv0); TensorView* tv2 = add(tv1, tv1); - TensorView* tv3 = sum(tv2, {sharded_dim}); + TensorView* tv3 = sum(tv2, {sharded_input_dim}); fusion->addInput(tv0); fusion->addOutput(tv1); fusion->addOutput(tv2); fusion->addOutput(tv3); - tv1->axis(sharded_dim)->parallelize(ParallelType::DIDx); - tv2->axis(sharded_dim)->parallelize(ParallelType::DIDx); + tv1->axis(sharded_input_dim)->parallelize(ParallelType::DIDx); + tv2->axis(sharded_input_dim)->parallelize(ParallelType::DIDx); tv3->axis(-1)->parallelize(ParallelType::DIDx); std::vector tvs = {tv0, tv1, tv2, tv3}; @@ -52,31 +53,31 @@ TEST_P(MultideviceShardingTest, UnshardedGlobalInput) { tv->setDeviceMesh(mesh); } - auto x0 = at::randn(input_size, tensor_options); + auto x0 = at::randn(input_shape, tensor_options); std::vector inputs = {x0}; auto x1 = shardTensor(x0, tv1); auto x2 = x1 + x1; - auto x3 = shardTensor(at::sum(x0 + x0, {sharded_dim}), tv3); + auto x3 = shardTensor(at::sum(x0 + x0, {sharded_input_dim}), tv3); FusionExecutorCache fec(std::move(fusion)); auto outputs = fec.runFusionWithInputs(inputs); testValidate(fec.fusion(), outputs, inputs, {x1, x2, x3}, __LINE__, __FILE__); } -// Test memory allocation of multidevice fusion with sharded input -// and replicated intermediates and output. -TEST_P(MultideviceShardingTest, ShardGlobalInput) { +// Test multidevice fusion with sharded input and replicated intermediates and +// output. +TEST_P(MultiDeviceReductionTest, ShardedInput_ReplicatedOutput) { auto [creates_concrete_tensor, sharded_dim] = GetParam(); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); int num_devices = communicator_->size(); auto mesh = DeviceMesh::createForNumDevices(num_devices); - std::vector unsharded_input_size = {3, 2, 5}; - unsharded_input_size[sharded_dim] = num_devices; + std::vector unsharded_input_shape = {3, 2, 5}; + unsharded_input_shape[sharded_dim] = num_devices; TensorView* tv0 = creates_concrete_tensor - ? makeContigConcreteTensor(unsharded_input_size) - : makeContigTensor(unsharded_input_size.size()); + ? makeContigConcreteTensor(unsharded_input_shape) + : makeContigTensor(unsharded_input_shape.size()); TensorView* tv1 = set(tv0); TensorView* tv2 = add(tv1, tv1); fusion->addInput(tv0); @@ -90,7 +91,7 @@ TEST_P(MultideviceShardingTest, ShardGlobalInput) { tv->setDeviceMesh(mesh); } - auto x1 = at::randn(unsharded_input_size, tensor_options); + auto x1 = at::randn(unsharded_input_shape, tensor_options); std::vector inputs = {shardTensor(x1, tv0)}; auto x2 = x1 * 2; FusionExecutorCache fec(std::move(fusion)); @@ -100,7 +101,7 @@ TEST_P(MultideviceShardingTest, ShardGlobalInput) { INSTANTIATE_TEST_SUITE_P( , - MultideviceShardingTest, + MultiDeviceReductionTest, testing::Combine(testing::Bool(), testing::Values(0, 1)), [](const testing::TestParamInfo>& info) -> std::string { @@ -115,7 +116,7 @@ INSTANTIATE_TEST_SUITE_P( return os.str(); }); -TEST_F(MultideviceShardingTest, Slice) { +TEST_F(MultiDeviceTest, Slice) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto mesh = DeviceMesh::createForNumDevices(communicator_->size()); @@ -150,7 +151,7 @@ TEST_F(MultideviceShardingTest, Slice) { __FILE__); } -TEST_F(MultideviceShardingTest, LayerNorm) { +TEST_F(MultiDeviceTest, LayerNorm) { std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto mesh = DeviceMesh::createForNumDevices(communicator_->size()); @@ -190,7 +191,7 @@ TEST_F(MultideviceShardingTest, LayerNorm) { __FILE__); } -TEST_F(MultideviceShardingTest, Issue2758) { +TEST_F(MultiDeviceTest, Issue2758) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -230,6 +231,9 @@ TEST_F(MultideviceShardingTest, Issue2758) { __FILE__); } +class MultiDeviceBroadcastTest : public MultiDeviceTest, + public testing::WithParamInterface {}; + // This test and the following `ExpandedBroadcast` test verify the expression // evaluator correctly binds the extent of a broadcast dimension to 1 and the // expanded extent to the tensor size. There used to be a bug where it @@ -237,7 +241,9 @@ TEST_F(MultideviceShardingTest, Issue2758) { // // `b(DID{i0})` and `b(i0)` bear the same semantics. The former is used more // often due to how parallelizeAllLike is implemented. -TEST_F(MultideviceShardingTest, Broadcast) { +TEST_P(MultiDeviceBroadcastTest, NotExpanded) { + const bool parallelizes_broadcast = GetParam(); + auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -250,7 +256,9 @@ TEST_F(MultideviceShardingTest, Broadcast) { .shape({1, -1}) .build(); in->setDeviceMesh(mesh); - in->axis(0)->parallelize(ParallelType::DIDx); + if (parallelizes_broadcast) { + in->axis(0)->parallelize(ParallelType::DIDx); + } TensorView* out = set(in); fusion->addInput(in); fusion->addOutput(out); @@ -262,7 +270,9 @@ TEST_F(MultideviceShardingTest, Broadcast) { testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); } -TEST_F(MultideviceShardingTest, ExpandedBroadcast) { +TEST_P(MultiDeviceBroadcastTest, Expanded) { + const bool parallelizes_broadcast = GetParam(); + auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -276,7 +286,9 @@ TEST_F(MultideviceShardingTest, ExpandedBroadcast) { .expanded({true, false}) .build(); in->setDeviceMesh(mesh); - in->axis(0)->parallelize(ParallelType::DIDx); + if (parallelizes_broadcast) { + in->axis(0)->parallelize(ParallelType::DIDx); + } TensorView* out = set(in); fusion->addInput(in); fusion->addOutput(out); @@ -288,4 +300,6 @@ TEST_F(MultideviceShardingTest, ExpandedBroadcast) { testValidate(fec.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); } +INSTANTIATE_TEST_SUITE_P(, MultiDeviceBroadcastTest, testing::Bool()); + } // namespace nvfuser From 5582769e6de4108c0a3c559cdf183778817b0954 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 3 Sep 2024 14:49:22 -0700 Subject: [PATCH 49/54] Only the first local rank prints pre-segmenter fusion IR. (#2895) --- csrc/kernel_cache.cpp | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 65e886036c9..b6894cd2c18 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -7,6 +7,14 @@ // clang-format on #include +#include +#include + +#include +#include +#include +#include + #include #include #include @@ -16,21 +24,13 @@ #include #include #include +#include #include #include #include #include -#include -#include #include -#include -#include -#include - -#include -#include - namespace nvfuser { namespace { @@ -1019,9 +1019,14 @@ FusionKernelRuntime::FusionKernelRuntime( fusion.get()); if (isDebugDumpEnabled(DebugDumpOption::FusionIrPreseg)) { - debug() << "Fusion IR after pre-segmenter optimization passes:" - << std::endl; - fusion->printMath(); + const auto& communicator = Communicator::getInstance(); + // Only the first local rank will print. Pre-segmenter fusion IR is device + // agnostic, so letting all ranks print isn't any more useful. + if (!communicator.is_available() || communicator.local_rank() == 0) { + debug() << "Fusion IR after pre-segmenter optimization passes:" + << std::endl; + fusion->printMath(); + } } // SchedulerRuntimeInfo modifies the fusion, so it is required for both From 524703293247a66aa4c187c10868989fe60cda39 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Tue, 3 Sep 2024 15:02:43 -0700 Subject: [PATCH 50/54] HopperSS.FullSwizzle (#2880) The SS version of https://github.com/NVIDIA/Fuser/pull/2710 --- csrc/device_lower/pass/index.cpp | 8 +- tests/cpp/test_mma.cpp | 169 ++++++++++++++++++++++++++++++- 2 files changed, 175 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index a833bfafe69..ce30a568c91 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1616,8 +1616,14 @@ void IndexLowering::handle(const MmaOp* mma) { // TODO: This is a temporary solution and only supports a single tile in // smem. auto tv = mma->inA()->as(); - auto base_addr = IrBuilder::baseAddressExpr(tv); auto swizzle = getSwizzleMode(tv); + // Because the entire tile is parallelized on MMA, which are trivial + // loops and always have zero loop variables, the result of lowerSrcIndex + // will be the address of the first element of the tile, which happens to + // be the information we need to provide to the hardware. + auto base_addr = lowerSrcIndex(tv, mma->out(), {}, true) + ->as() + ->index(); int64_t leading_bytes = core_matrix_outer_size * getBytesFromSwizzle(swizzle); // swizzle period in bytes int64_t inner_size = diff --git a/tests/cpp/test_mma.cpp b/tests/cpp/test_mma.cpp index b7dab16e626..d4e8d8d614a 100644 --- a/tests/cpp/test_mma.cpp +++ b/tests/cpp/test_mma.cpp @@ -997,7 +997,6 @@ TEST_P(HopperSS, SingleTile) { moveInnerBroadcastLeft(tv0); moveInnerBroadcastLeft(tv1); - // Hopper tensor core assumes K major, so we are using !transpose_a here. tv0->applyMmaSwizzle(swizzle_a); tv1->applyMmaSwizzle(swizzle_b); @@ -1030,6 +1029,174 @@ TEST_P(HopperSS, SingleTile) { EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); } +// See the note in HopperRS.FullSwizzle for the explanation of this test. +TEST_P(HopperSS, FullSwizzle) { + Fusion fusion; + FusionGuard fg(&fusion); + + bool m_is_inner = layout == MmaLayout::NT || layout == MmaLayout::NN; + auto swizzle_size_a = getBytesFromSwizzle(swizzle_a) / dataTypeSize(dtype); + auto inner_size_a = m_is_inner ? getM(macro) : getK(macro); + bool multiple_a = swizzle_size_a / inner_size_a > 1; + + bool n_is_inner = layout == MmaLayout::TT || layout == MmaLayout::NT; + auto swizzle_size_b = getBytesFromSwizzle(swizzle_b) / dataTypeSize(dtype); + auto inner_size_b = n_is_inner ? getN(macro) : getK(macro); + bool multiple_b = swizzle_size_b / inner_size_b > 1; + + if (!multiple_a && !multiple_b) { + GTEST_SKIP() + << "Already tested in SingleTile, not interested in testing it again"; + } + + if ((multiple_a && swizzle_size_a % inner_size_a != 0) || + (multiple_b && swizzle_size_b % inner_size_b != 0)) { + GTEST_SKIP() + << "We will be using swizzle size as CTA tile size, so it must be divisible"; + } + + int64_t m = (multiple_a && m_is_inner) ? swizzle_size_a : getM(macro); + int64_t n = (multiple_b && n_is_inner) ? swizzle_size_b : getN(macro); + int64_t k1 = (multiple_a && !m_is_inner) ? swizzle_size_a : getK(macro); + int64_t k2 = (multiple_b && !n_is_inner) ? swizzle_size_b : getK(macro); + + if (k1 != k2) { + GTEST_SKIP() + << "This test assumes the CTA tile size of A and B must be the same"; + } + + auto shapes = matmulAtInputShape3DHopperSS(m, n, k1, layout); + + auto tv0 = makeConcreteTensor(shapes.first, dtype); + auto tv1 = makeConcreteTensor(shapes.second, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Just doing a gmem->smem copy + tv0 = set(tv0); + tv0->setMemoryType(MemoryType::Shared); + tv1 = set(tv1); + tv1->setMemoryType(MemoryType::Shared); + + int axes = 0; + switch (layout) { + case MmaLayout::NT: + axes = 0; + break; + case MmaLayout::TT: + case MmaLayout::NN: + axes = 1; + break; + case MmaLayout::TN: + axes = 2; + break; + default: + NVF_ERROR("Invalid layout"); + } + auto tv2 = fusedMultiplySum(tv0, tv1, {axes}); + + // Reorder the accumulator as [M, N, K] + switch (layout) { + case MmaLayout::TT: + // [M, K, N] -> [M, N, K] + tv2->reorder({{-2, -1}}); + break; + case MmaLayout::TN: + // [M, N, K] + break; + case MmaLayout::NT: + // [K, M, N] -> [M, N, K] + tv2->reorder({{-3, -1}}); + break; + case MmaLayout::NN: + // [N, K, M] -> [M, N, K] + tv2->reorder({{-1, -3}}); + break; + default: + NVF_ERROR("Invalid layout"); + } + tv2->commitLeafToLogical(); + + fusion.addOutput(tv2); + + auto mma_ops = ir_utils::getOpsOfType(&fusion); + NVF_CHECK( + 1 == mma_ops.size(), + "Invalid number of MmaOp instances in fusion definition, expected 1, got ", + mma_ops.size()); + mma_ops.front()->setMacro(macro); + + auto tv2c = tv2->cacheBefore(); + + // Bring related dims to innermost, that is: + // - Reorder tv0 as [1, M, K] or [1, K, M] + // - Reorder tv1 as [1, N, K] or [1, K, N] + moveInnerBroadcastLeft(tv0); + moveInnerBroadcastLeft(tv1); + + // Just schedule tv0 and tv1 the same way as in SingleTile. Note that although + // the schedule are the same, the memory layout is different. + // For example, assume that the inner size is 16, and the swizzle size is 64. + // For the case of SingleTile, the input tensor size will be 16, so the inner + // dimension will be split as: + // 1, 64 = split(16, 64) + // For the case of FullSwizzle, the input tensor size will be 64, so the inner + // dimension will be split as: + // 1, 64 = split(64, 64) + tv0->applyMmaSwizzle(swizzle_a); + tv1->applyMmaSwizzle(swizzle_b); + + naivelyParallelize(tv0); + naivelyParallelize(tv1); + + // [M, N, K] + int64_t inline_pos = 0; + if (multiple_a && m_is_inner) { + tv2c->split(-3, getM(macro)); + tv2->split(-2, getM(macro)); + inline_pos++; + } + if (multiple_b && n_is_inner) { + tv2c->split(-2, getN(macro)); + tv2c->reorder({{-3, -4}}); + tv2->split(-1, getN(macro)); + tv2->reorder({{-2, -3}}); + inline_pos++; + } + if ((multiple_a && !m_is_inner) || (multiple_b && !n_is_inner)) { + tv2c->split(-1, getK(macro)); + tv2c->reorder({{-2, -4}}); + } + // [Mo, No, Ko, Mi, Ni, Ki] + + { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv2c->getLoopDomain()); + tv2c->setLoopDomain(s.as()); + tv2c->setAllocationDomain(s.as(), true); + } + { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv2->getLoopDomain()); + tv2->setLoopDomain(s.as()); + } + + tv2c->inlineAt(inline_pos); + + auto inputs = + matmulAtInput3DHopperSS(m, n, k1, layout, data_type_to_aten(dtype)); + + FusionExecutor fe; + fe.compileFusion( + &fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = atMatmul( + inputs.first.squeeze().to(at::kFloat), + inputs.second.squeeze().to(at::kFloat), + layout); + EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); +} + TEST_P(HopperSS, SingleTileWithTMALoad) { Fusion fusion; FusionGuard fg(&fusion); From 8147bfdd52a84e448c3fc893545b17c2b7bec40e Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 3 Sep 2024 19:59:44 -0700 Subject: [PATCH 51/54] Print the dtype of a TensorView. (#2889) This helps me to debug some numerical problems because nvFuser has a non-obvious dtype promotion rule. However, it has a large blast radius so I'd love to hear your opinions. Example: ``` [==========] Running 1 test from 1 test suite. [----------] Global test environment set-up. [----------] 1 test from AliasTest [ RUN ] AliasTest.QKVSplitBackprop Fusion IR after pre-segmenter optimization passes: Inputs: T0_g_float[ iS0{16}, iS1{128}, iS2{768} ] T1_g_float[ iS3{16}, iS4{128}, iS5{768} ] T2_g_float[ iS6{16}, iS7{128}, iS8{768} ] Outputs: T7_g_float[ rS24{16}, rS25{128}, iS26{2304} ] T8_g_float[ iS32{( 16 * 128 )}rf, iS29{2304} ] T9_g_float[ iS34{2304}, iS33{( 16 * 128 )} ] %kernel_math { T3_l_float[ iS9{16}, iS10{128}, iS12{2304}rf ] = pad( T0_g_float[ iS0{16}, iS1{128}, iS2{768} ], {0, 0, 0, 0, 0, 1536} ) i23 = 0 + 768; T4_l_float[ iS13{16}, iS14{128}, iS16{( ( ( 0 + 768 ) + 768 ) + 768 )}rf ] = pad( T1_g_float[ iS3{16}, iS4{128}, iS5{768} ], {0, 0, 0, 0, i23, 768} ) i36 = i23 + 768; T5_l_float[ iS17{16}, iS18{128}, iS20{( ( ( 0 + 768 ) + 768 ) + 768 )}rf ] = pad( T2_g_float[ iS6{16}, iS7{128}, iS8{768} ], {0, 0, 0, 0, i36, 0} ) T6_l_float[ iS21{16}, iS22{128}, iS23{2304} ] = cat( T3_l_float[ iS9{16}, iS10{128}, iS12{2304}rf ], T4_l_float[ iS13{16}, iS14{128}, iS16{( ( ( 0 + 768 ) + 768 ) + 768 )}rf ], T5_l_float[ iS17{16}, iS18{128}, iS20{( ( ( 0 + 768 ) + 768 ) + 768 )}rf ], 2 ) T7_g_float[ rS24{16}, rS25{128}, iS26{2304} ] = reduction( T6_l_float[ iS21{16}, iS22{128}, iS23{2304} ], op = add, initial value = float(0), allreduce = false ) T10_l_float[ iS35{16}, iS36{128}, iS37{2304} ] = SegmenterSet( T6_l_float[ iS21{16}, iS22{128}, iS23{2304} ] ) T8_g_float[ iS32{( 16 * 128 )}rf, iS29{2304} ] = view( T10_l_float[ iS35{16}, iS36{128}, iS37{2304} ] ) T9_g_float[ iS34{2304}, iS33{( 16 * 128 )} ] = Set.Permute( T8_g_float[ iS32{( 16 * 128 )}rf, iS29{2304} ], cache_op=Streaming ) } // %kernel_math [ OK ] AliasTest.QKVSplitBackprop (501 ms) [----------] 1 test from AliasTest (501 ms total) [----------] Global test environment tear-down [==========] 1 test from 1 test suite ran. (501 ms total) [ PASSED ] 1 test. ``` --- csrc/fusion.cpp | 4 ++-- csrc/tensor_view.cpp | 2 +- tests/python/test_python_frontend.py | 9 ++------- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 222a3b1afb6..459ce465692 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -521,12 +521,12 @@ void Fusion::printMath(bool from_outputs_only) { auto exprs_for_print = exprs(); debug() << "Inputs:" << std::endl; for (auto inp : inputs()) { - debug() << " " << inp << ", " << inp->getDataType().value() << std::endl; + debug() << " " << inp << std::endl; } debug() << "Outputs:" << std::endl; for (auto out : outputs()) { - debug() << " " << out << ", " << out->getDataType().value() << std::endl; + debug() << " " << out << std::endl; } // If we want everything in the fusion, grab all values without uses to diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 8a34db6b7aa..d96bd45afda 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -63,7 +63,7 @@ std::string TensorView::toString(int indent_size) const { default: NVF_ERROR(false, "Unknown tensor memory type."); } - ss << domain()->toString(indent_size); + ss << "_" << dtype() << domain()->toString(indent_size); if (getComputeAtPosition() > 0) { ss << " ca_pos( "; diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index f52ee06aa65..664b254c8c6 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4249,9 +4249,7 @@ def fusion_func(fd: FusionDefinition): with pytest.raises( Exception, - match=re.escape( - "Expected input 0, T0_g[ iS0{i0} ], to be an at::Tensor but got scalar 2" - ), + match="Expected input 0, .*, to be an at::Tensor but got scalar 2", ): nvf_out = fd.execute([scalar_inp, scalar_inp]) @@ -4265,10 +4263,7 @@ def fusion_func(fd: FusionDefinition): with pytest.raises( Exception, - match=re.escape( - "Expected input 0, T0_g[ iS0{i0} ], to be bound to a tensor of dtype float," - " but got a tensor of dtype __half" - ), + match="Expected input 0, .*, to be bound to a tensor of dtype float, but got a tensor of dtype __half", ): wrong_tensor_inp = torch.rand((15,), dtype=torch.float16, device="cuda:0") nvf_out = fd.execute([wrong_tensor_inp, 2.0]) From c7c67c4de45aed067272f8d5f19ad87d8d7d547d Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 3 Sep 2024 20:11:37 -0700 Subject: [PATCH 52/54] Print sizes when mismatch. (#2896) --- tests/cpp/validator.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/cpp/validator.cpp b/tests/cpp/validator.cpp index 2914f645207..36bd8555146 100644 --- a/tests/cpp/validator.cpp +++ b/tests/cpp/validator.cpp @@ -97,7 +97,10 @@ void testValidate( static_cast( TensorDomain::noReductions(out_tv->getLogicalDomain()) .size()), - "Dimensionality mismatch in outputs."); + "Dimensionality mismatch in outputs: ", + aten_output_tensor.sizes(), + " vs ", + fusion_output_tensor.sizes()); auto tolerance_values = getTolerance(out_tv->getDataType().value(), reduction_size, tolerances); From d4dd6d381d3881067b84ef324d1b2d9f2ab9a731 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 4 Sep 2024 10:57:48 -0400 Subject: [PATCH 53/54] Introduce multi matmul scheduler (#2719) This is the first in a series of PRs that represents a refactor of the matmul scheduler to increase the flexibility in the number of operands and matmul patterns we can schedule in a single kernel. This PR: - Introduces the `MultipleMatmulScheduler` class which will eventually implement and extend the functionality in the current `scheduleMatmul`. - Introduces a new option `EnableOption::FuseMultipleMatmuls` which enables this scheduler in place of the default matmul scheduler. - Implements some basic functionality in `MultipleMatmulScheduler`. Enough to translate matmul patterns and define most of the relevant tensors (`splitk_sum` and `smem_epilogue` will be defined and all tensors will be scheduled in later PRs). In the next PRs I will - Schedule the prologue and introduce parametrized tests to check that the scheduled fusion matches the current scheduler - Schedule epilogue tensors without split-K and smem_epilogue - Schedule split-K sum - Schedule smem_epilogue - Update the heuristic to adjust parameters when multiple matmul patterns are detected. This is needed to ensure we don't use register/smem buffers that are too large. - Update canSchedule checks to accept multiple matmuls when the multi matmul scheduler is enabled --- CMakeLists.txt | 1 + csrc/options.cpp | 1 + csrc/options.h | 1 + csrc/scheduler/matmul.cpp | 6 + csrc/scheduler/mma_utils.cpp | 29 +++ csrc/scheduler/mma_utils.h | 8 + csrc/scheduler/multi_matmul.cpp | 423 ++++++++++++++++++++++++++++++++ csrc/scheduler/multi_matmul.h | 20 ++ 8 files changed, 489 insertions(+) create mode 100644 csrc/scheduler/multi_matmul.cpp create mode 100644 csrc/scheduler/multi_matmul.h diff --git a/CMakeLists.txt b/CMakeLists.txt index b043b7f1ca6..575ae57a8eb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,6 +205,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/heuristic_types.cpp ${NVFUSER_SRCS_DIR}/scheduler/mark_aliases.cpp ${NVFUSER_SRCS_DIR}/scheduler/matmul.cpp + ${NVFUSER_SRCS_DIR}/scheduler/multi_matmul.cpp ${NVFUSER_SRCS_DIR}/scheduler/matmul_heuristic_plugin.cpp ${NVFUSER_SRCS_DIR}/scheduler/matmul_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp diff --git a/csrc/options.cpp b/csrc/options.cpp index 8618384ccf1..2641216bd23 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -153,6 +153,7 @@ std::unordered_map> Options< EnableOption>::getOptionsFromEnv() { const std::unordered_map available_options = { {"fuse_matmul", EnableOption::FuseMatmul}, + {"fuse_multiple_matmuls", EnableOption::FuseMultipleMatmuls}, {"id_model", EnableOption::IdModel}, {"kernel_db", EnableOption::KernelDb}, {"kernel_profile", EnableOption::KernelProfile}, diff --git a/csrc/options.h b/csrc/options.h index 2d8a48a0ec0..b1b706e8f15 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -92,6 +92,7 @@ enum class DebugDumpOption { //! enum class EnableOption { FuseMatmul, //! Enable automatic fusion of matmul and linear ops + FuseMultipleMatmuls, //! Allow fusing more than one matmul in a single kernel IdModel, //! Enable IdModel KernelDb, //! Enable Kernel Database KernelProfile, //! Enable intra-kernel performance profiling diff --git a/csrc/scheduler/matmul.cpp b/csrc/scheduler/matmul.cpp index 0191acf3a91..cf377e26e92 100644 --- a/csrc/scheduler/matmul.cpp +++ b/csrc/scheduler/matmul.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include // NOTE: included to avoid compilation error caused by missing destructor in @@ -734,6 +735,11 @@ void scheduleSplitKSum( } // namespace void scheduleMatmul(Fusion* fusion, const MatmulParams& params) { + if (isOptionEnabled(EnableOption::FuseMultipleMatmuls)) { + scheduleMultipleMatmuls(fusion, params); + return; + } + FusionGuard fg(fusion); // Make sure we don't have global memory set on intermediate tensors from diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index bd67a01caf1..268b50dd4bd 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1258,6 +1258,7 @@ TensorRolesMapOpt getTensorRoles( Fusion* fusion, const IdModel& id_model, const DimRolesMap& dim_roles) { + NVF_ERROR(fusion != nullptr); const auto mma_input_candidates = ir_utils::filterByType(fusion->inputs()).vector(); if (mma_input_candidates.empty()) { @@ -2012,6 +2013,34 @@ std::vector canonicalDimOrdering( return ordering; } +std::optional> allPatternRoles( + IdModel& id_model, + const std::vector& patterns) { + Fusion* fusion = nullptr; + DimRolesMap id_roles; + for (const MatmulPattern& pattern : patterns) { + if (fusion == nullptr) { + fusion = pattern.output->fusion(); + } else { + NVF_ERROR(fusion == pattern.output->fusion()); + } + mma_utils::DimRolesMap pattern_id_roles = pattern.getDimRoles(id_model); + for (const auto& [g, role] : pattern_id_roles) { + const auto& [it, inserted] = id_roles.try_emplace(g, role); + if (!inserted && it->second != role) { + return std::nullopt; + } + } + } + const auto tensor_roles_opt = + mma_utils::getTensorRoles(fusion, id_model, id_roles); + if (!tensor_roles_opt.isValid()) { + return std::nullopt; + } + return std::pair{ + id_roles, tensor_roles_opt.getData()}; +} + } // namespace mma_utils } // namespace nvfuser diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index b5f12547d7c..023c05ed2f5 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -431,6 +431,14 @@ std::vector canonicalDimOrdering( const mma_utils::DimRolesMap& dim_roles, const ValGraph& permissive_graph); +//! Returns roles maps which have been merged across individual maps generated +//! by the provided matmul patterns. +//! +//! Returns std::nullopt if two patterns have incompatible roles +std::optional> allPatternRoles( + IdModel& id_model, + const std::vector& patterns); + } // namespace mma_utils } // namespace nvfuser diff --git a/csrc/scheduler/multi_matmul.cpp b/csrc/scheduler/multi_matmul.cpp new file mode 100644 index 00000000000..6de3ad348ce --- /dev/null +++ b/csrc/scheduler/multi_matmul.cpp @@ -0,0 +1,423 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// NOTE: included to avoid compilation error caused by missing destructor in +// 'SchedulerRuntimeInfo' +#include +#include "mma_type.h" + +namespace nvfuser { + +namespace { + +// A matmul kernel might perform multiple matmuls; i.e. there can be multiple +// MmaOps in the scheduled tensor. Each one outputs a TensorView* which we call +// an mma_result. Each MmaOp will also have two input TensorViews which we call +// "ab" and "bb" since they are the immediate A and B operands and they contain +// broadcast dimensions. Again there can be multiple abs and multiple bbs in +// one fusion. These TensorViews are loaded from global memory tensors that we +// call "a" and "b" into shared memory tensors called acw_smem and bcw_smem. +// They are loaded from shared memory to register buffers we call "acr" and +// "bcr" ("cr" meaning "cache read" in this context). +// +// Putting this all together we have the following order for a simple matmul +// +// a -> acw_smem -> acr -> ... -> ab +// \ +// mma_result -> ... -> dc -> d +// / +// b -> bcw_smem -> bcr -> ... -> bb +// +// The ... indicate that there might be other tensors involved in a prologue or +// epilogue section at that location. +// +// In this example there are two matmuls both using the same "a" operand: +// +// b1 -> bcw_smem1 -> bcr1 -> ... -> bb1 +// \ +// mma_result1 +// / \ +// a -> acw_smem -> acr -> ... -> ab ... -> dc -> d +// \ / +// mma_result2 +// / +// b2 -> bcw_smem2 -> bcr2 -> ... -> bb2 +// +// Note that there can be more than one output d and each one will have its own +// register cache dc. +// +// Split-K and smem epilogue unswizzling add two additional tensors for each +// mma in the fusion: splitk_sum and smem_epilogue. +// +// // No split-K, no smem epilogue unswizzling: +// mma_result -> ... -> dc -> d +// // split-K, no smem epilogue unswizzling: +// mma_result -> splitk_sum -> ... -> dc -> d +// // smem epilogue unswizzling, no split-K: +// mma_result -> smem_epilogue -> ... -> dc -> d +// // split-K and smem epilogue unswizzling: +// mma_result -> smem_epilogue -> splitk_sum -> ... -> dc -> d +// +// These additional tensors are added to each mma_result in the fusion. +// +// Each of the named tensors above is scheduled differently. We schedule them +// by building AbstractTensors for each tensor category; these are held in +// MultipleMatmulScheduler::schedules_. +class MultipleMatmulScheduler { + public: + MultipleMatmulScheduler(Fusion* fusion, const MatmulParams& params) + : fusion_(fusion), + params_(params), + id_model_(fusion, /*build_graphs=*/false) {} + + void run() { + // Clears memory spaces on intermediate tensors, calls + // cache{After,Before,Fork} on inputs and outputs + cacheInputsAndOutputs(); + + // Finds matmul patterns and translates them to MmaOps, then finds tensor + // and dimension roles for all tensors in the fusion + findPatterns(); + translatePatterns(); + // translatePatterns changes the TensorView graph, so we build the IdModel + // afterward + buildIdModel(); + findRoles(); + + // Defines acw_smem/bcw_smem and acr/bcr by possibly calling cacheAfter. + // This also collects mma_results_ + defineOperandCaches(); + + // TODO: Remove this as the methods below are implemented + return; + + // Schedules: + // - global->smem (cp.async) + // - smem->register (ldmatrix) + // - prologue computation in registers, including broadcast to e.g. + // ab=[iM, bN, iK] + schedulePrologues(); + + // schedule mma instruction output (mma_result) + scheduleMmaResults(); + + // schedule epilogue + scheduleEpilogue(); + + // schedule splitk_sum + scheduleSplitKSum(); + + setUpInlining(); + + // set up circular buffering. This must come after everything up to + // mma_result is scheduled, since everything in the main loop will need to + // be rotated + setUpCircularBuffering(); + } + + private: + void cacheInputsAndOutputs() { + // Make sure we don't have global memory set on intermediate tensors from + // fusion segmentation + scheduler_utils::clearMemorySpace(fusion_); + + // Cache inputs + scheduler_utils::cacheInputs(fusion_, /*unroll=*/true); + + // Cache and fork outputs + cached_outputs_ = + scheduler_utils::cacheAndForkOutputs(fusion_, /*unroll=*/true); + } + + void findPatterns() { + patterns_ = mma_utils::findMatmulPatterns(fusion_); + NVF_ERROR(!patterns_.empty(), "No matmul patterns were found"); + } + + void countDims() { + NVF_ERROR(!patterns_.empty()); + TensorView* mma_result = patterns_.front().output; + num_device_dims_ = numDeviceDims(mma_result); + for (const auto& it : id_roles_) { + if (it.second == MatmulDimRole::Batch) { + // All batch dims will be merged into one, if any exist + num_local_batch_dims_ = 1; + } + } + num_splitk_dims_ = params_.splitk_factor > 1 ? 1 : 0; + // Subtract 6 for the [Mo, No, Ko, Mi, Ni, Ki] + num_device_and_batch_dims_ = num_device_dims_ + num_local_batch_dims_; + } + + void translatePatterns() { + mma_results_.reserve(patterns_.size()); + for (mma_utils::MatmulPattern& pattern : patterns_) { + MmaOp* mma = pattern.translateToMmaOp(); + mma_results_.push_back(mma->out()->as()); + } + } + + void buildIdModel() { + id_model_ = IdModel(fusion_, /*build_graphs=*/false); + id_model_.buildPermissiveGraph(); + graph_ = &id_model_.idGraph(IdMappingMode::PERMISSIVE); + } + + // Get tensor roles and id roles + // When there are multiple matmul patterns, we can have conflicting roles. + // For now we throw an error if this is the case. + // TODO: This should be checked in canScheduleCompileTime + void findRoles() { + const auto roles_opt = mma_utils::allPatternRoles(id_model_, patterns_); + NVF_ERROR( + roles_opt.has_value(), + "Incompatible roles found between matmul patterns"); + std::tie(id_roles_, tensor_roles_) = roles_opt.value(); + + mma_utils::MatmulOperandInnerDimsOpt inner_dims_opt = + mma_utils::getOperandInnerDims(id_model_, id_roles_, tensor_roles_); + NVF_ERROR(inner_dims_opt.isValid(), inner_dims_opt.getErrorMsg()); + inner_dims_ = inner_dims_opt.getData(); + + as_ = tensor_roles_.at(MatmulTensorRole::OPERAND_A); + bs_ = tensor_roles_.at(MatmulTensorRole::OPERAND_B); + + countDims(); + } + + // Including current tensor naming convention for reference, + // this is very temporary and will change over time and + // in fact the whole body of this function will + // eventually be a set of utility functions for different + // sections of matmul(fusion) kernels, with + // each having its own build out to do. + // + // Current naming convention is based on the following formula: + // + // d = alpha * (a x b) + beta * c + // + // and is defined in the following way: + // + // operands assumed in global memory : a, b, c + // + // registers staging global load : ar, br (short for a/b read) + // + // shared mem cache of operands : acw_smem, bcw_smem (short for a/b + // cache_write smem) + // + // registers at shared memory load output : acr, bcr (short for a/b cache + // read) + // + // register tensor input to the actual mma op: ab, bb (short for a/b + // broadcasted) + // + // accumulator register: mma_result + // - mma_result is MmaOp output if there is epilogue + // - mma_result is dc (short for d cache) if there is no epilogue + // + // result in global memory: d + + // Currently the support is for a, b, c and d as fusion inputs/outputs + // aka. no prolog fusion yet. + void defineOperandCaches() { + cacheOperandsToSmem(as_, acw_smems_, params_.supported_vec_size.a); + addSetsForCacheReads(acw_smems_, acrs_); + + cacheOperandsToSmem(bs_, bcw_smems_, params_.supported_vec_size.b); + addSetsForCacheReads(bcw_smems_, bcrs_); + + // Now that we are finished possibly redefining the inputs to the MmaOps, + // we can set the macro for those ops + for (TensorView* mma_result : mma_results_) { + MmaOp* mma = dynamic_cast(mma_result->definition()); + NVF_ERROR(mma != nullptr); + mma->setMacro(params_.mma_macro); + } + } + + void cacheOperandsToSmem( + const std::vector& operands, + std::vector& smem_operands, + int64_t vec_size) { + // Use cp.async as requested in scheduler params. + smem_operands.resize(operands.size(), nullptr); + for (size_t i : c10::irange(operands.size())) { + TensorView* operand = operands[i]; + CacheOp cache_op = CacheOp::Unspecified; + if (params_.async_gmem_load_operands) { + int64_t vec_bytes = vec_size * dataTypeSize(operand->dtype()); + NVF_CHECK( + vec_bytes == 4LL || vec_bytes == 8LL || vec_bytes == 16LL, + "Unsupported async vectorization size ", + vec_size, + " = ", + vec_bytes, + " bytes for operand ", + operand->toString(), + " which has data type ", + operand->dtype(), + ". Size must be 4, 8, or 16 bytes. ", + "MatmulParams::async_gmem_load_operands should be set to false in this case."); + cache_op = vec_bytes == 16LL ? CacheOp::Global : CacheOp::AllLevels; + }; + + NVF_ERROR(operand->uses().size() == 1); + smem_operands[i] = ir_utils::consumerTvsOf(operand).at(0); + + LoadStoreOpType load_op = params_.async_gmem_load_operands + ? LoadStoreOpType::CpAsync + : LoadStoreOpType::Set; + + smem_operands[i]->definition()->as()->setOpType(load_op); + smem_operands[i]->definition()->as()->setCacheOp(cache_op); + smem_operands[i]->setMemoryType(MemoryType::Shared); + } + } + + // We add two LoadStore operators to the inputs of our fusions. The first + // one is for a read from global memory and the second one (below) is for a + // cache read. As an optimizaton, we avoid adding an operator if there's an + // existing LoadStoreOp present. Please note that for the second LoadStore + // we don't propagate the allocation domain, since the scheduler sets the + // allocation domain in the registers. + void addSetsForCacheReads( + const std::vector& tv_smems, + std::vector& tv_rs) { + tv_rs.resize(tv_smems.size(), nullptr); + for (size_t i : c10::irange(tv_smems.size())) { + TensorView* tv_smem = tv_smems[i]; + TensorView*& tv_r = tv_rs[i]; + + // There can be multiple uses for example if we have A @ B1 + A @ B2 + // then A will be cached to smem then it might be loaded into two + // separate register buffers, one for each mma. Instead, we will load + // it once into registers then re-use the register buffer for both + // mmas. + if (auto ldst = dynamic_cast(tv_smem->uses().at(0)); + ldst && tv_smem->uses().size() == 1) { + tv_r = ldst->out()->as(); + ldst->setOpType(LoadStoreOpType::LdMatrix); + } else { + tv_r = cacheAfter( + tv_smem, + LoadStoreOpType::LdMatrix, + CacheOp::Unspecified, + /*propagate_allocation_domain=*/false); + } + } + } + + //! This calls orig->cacheAfter() and also updates the permissive graph to + //! reflect the new IterDomain mappings + TensorView* cacheAfter( + TensorView* orig, + LoadStoreOpType op_type = LoadStoreOpType::Set, + CacheOp cache_op = CacheOp::AllLevels, + bool propagate_allocation_domain = false) { + const std::vector orig_alloc = + orig->getMaybeAllocationDomain(); + + TensorView* c = + orig->cacheAfter(op_type, cache_op, propagate_allocation_domain); + + if (propagate_allocation_domain) { + const std::vector cache_alloc = + c->getMaybeAllocationDomain(); + NVF_ERROR(orig_alloc.size() == cache_alloc.size()); + for (size_t i : c10::irange(orig_alloc.size())) { + ValGroup vg = graph_->toGroup(orig_alloc[i]); + graph_->initializeVal(cache_alloc[i], vg); + } + } + + const std::vector orig_logical = + TensorDomain::noReductions(orig->getLogicalDomain()); + const std::vector cache_logical = c->getLogicalDomain(); + // in split-K we do rFactor which gives us a full = sum(partial) + // where partial has root domain that matches the logical domain of the + // original tensor. The logical domain contains Iteration transforms of the + // Reduction axis in the original mma output. + NVF_ERROR(orig_logical.size() == cache_logical.size()); + for (size_t i : c10::irange(orig_logical.size())) { + ValGroup vg = graph_->toGroup(orig_logical[i]); + graph_->initializeVal(cache_logical[i], vg); + } + + return c; + } + + void scheduleMmaResults() { + NVF_ERROR(false, "scheduleMmaResults is not yet implemented"); + } + + void schedulePrologues() { + NVF_ERROR(false, "schedulePrologues is not yet implemented"); + } + + void scheduleEpilogue() { + NVF_ERROR(false, "scheduleEpilogue is not yet implemented"); + } + + void scheduleSplitKSum() { + NVF_ERROR(false, "scheduleSplitKSum is not yet implemented"); + } + + void setUpInlining() { + NVF_ERROR(false, "setUpInlining is not yet implemented"); + } + + // NOTE: this should be called after acw_smem, acr, ..., ab, and mma_result + // transforms have been applied and inlining + void setUpCircularBuffering() { + NVF_ERROR(false, "setUpCircularBuffering is not yet implemented"); + } + + private: + Fusion* fusion_; + const MatmulParams& params_; + IdModel id_model_; + // Permissive graph of id_model_, which we modify at times using e.g. + // AbstractTensor.split or by mapping vals in cacheAfter and rFactor + ValGraph* graph_ = nullptr; + std::vector patterns_; + mma_utils::DimRolesMap id_roles_; + mma_utils::TensorRolesMap tensor_roles_; + mma_utils::MatmulOperandInnerDims inner_dims_; + + int64_t num_splitk_dims_ = 0, num_device_dims_ = 0, num_local_batch_dims_ = 0, + num_device_and_batch_dims_ = 0; + + std::vector> cached_outputs_; + + std::vector as_, bs_, acw_smems_, bcw_smems_, acrs_, bcrs_, abs_, + bbs_, mma_results_, splitk_sums_, smem_epilogues_; +}; + +} // namespace + +void scheduleMultipleMatmuls(Fusion* fusion, const MatmulParams& params) { + FusionGuard fg(fusion); + + MultipleMatmulScheduler(fusion, params).run(); +} + +} // namespace nvfuser diff --git a/csrc/scheduler/multi_matmul.h b/csrc/scheduler/multi_matmul.h new file mode 100644 index 00000000000..3b8f060a793 --- /dev/null +++ b/csrc/scheduler/multi_matmul.h @@ -0,0 +1,20 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include + +namespace nvfuser { + +NVF_API void scheduleMultipleMatmuls( + Fusion* fusion, + const MatmulParams& params); + +} // namespace nvfuser From bcdd18158404e050f273a8641cee8f3e24af42d5 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Wed, 4 Sep 2024 09:42:37 -0700 Subject: [PATCH 54/54] Allocation changes for TMA Circular Buffering (#2824) ## Summary ## It is the changes to the allocation lowering pass from https://github.com/NVIDIA/Fuser/pull/2773. ## Details ## ### GpuLower ### - `ldst_mbarrier_token_map_` maps `LoadStoreOp` to mbarrier tokens, which are represented as `TensorView` of number of pipeline stages. - `mbarrier_token_smem_alloc_set_` tracks the `kir::Allocate` expressions for the mbarriers and their tokens. - `ldst_mbarrier_index_map_` maps the cloned `LoadStoreOp` in the prologue and main loops to their indexed mbarrier. ### Allocation ### - In the allocation pass, create shared memory allocations and operations around `LoadStoreOp` expression. ```cpp // Created tokens, mbarriers, init, and inval operations in allocation pass. for (circular_buffer_loop) { __shared__ int64_t tokens[num_stages]; __shared__ int64_t mbarrier[num_stages]; init(mbarrier); cp.async.bulk(data, mbarrier); inval(mbarrier); } ``` ## AliasMemory ## - The mbarrier and its token are mapped together. The token is the mbarrier state of the last phase. For simplicity, mark token liveness when mbarrier is initialized and invalidated. - Apply `markWrite` for mbarrier and its token when the expression is `MBarrierInit` - Apply `markRead` for mbarrier and its token when the expression is `MBarrierInvalidate` --- csrc/device_lower/lower2device.h | 39 +++++++ csrc/device_lower/pass/alias_memory.cpp | 49 +++++++-- csrc/device_lower/pass/allocation.cpp | 140 +++++++++++++++++++----- 3 files changed, 190 insertions(+), 38 deletions(-) diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h index 54626da3586..3c2eb0d0ebb 100644 --- a/csrc/device_lower/lower2device.h +++ b/csrc/device_lower/lower2device.h @@ -226,6 +226,32 @@ class GpuLower : public NonCopyable { return ldst_mbarrier_map_; } + std::unordered_map& ldstMBarrierTokenMap() { + return ldst_mbarrier_token_map_; + } + + const std::unordered_map& ldstMBarrierTokenMap() + const { + return ldst_mbarrier_token_map_; + } + + std::unordered_set& mBarrierTokenSmemAllocSet() { + return mbarrier_token_smem_alloc_set_; + } + + const std::unordered_set& mBarrierTokenSmemAllocSet() const { + return mbarrier_token_smem_alloc_set_; + } + + std::unordered_map& ldstMBarrierIndexMap() { + return ldst_mbarrier_index_map_; + } + + const std::unordered_map& + ldstMBarrierIndexMap() const { + return ldst_mbarrier_index_map_; + } + bool isNvFuserZeroEnabled() { if (isOptionDisabled(DisableOption::MagicZero)) { return false; @@ -359,6 +385,19 @@ class GpuLower : public NonCopyable { //! for vectorization. std::vector> validations_; + // Keep track of placeholders for tokens returned by arrive/expected tx + // mbarrier operations for each load/store operation that requires such + // synchronization + std::unordered_map ldst_mbarrier_token_map_; + + // Collection of kir::Allocate for smem buffers used for mbarrier and token + // objects from cpAsyncBulk synchronization + std::unordered_set mbarrier_token_smem_alloc_set_; + + // Keep track what mbarrier object is used in load/store operation that + // requires such synchronization, required by indexing pass + std::unordered_map ldst_mbarrier_index_map_; + Fusion* fusion_ = nullptr; // A temporary flag which is true if the fusion uses any feature that requires diff --git a/csrc/device_lower/pass/alias_memory.cpp b/csrc/device_lower/pass/alias_memory.cpp index f2df2f0824a..59474f8ae63 100644 --- a/csrc/device_lower/pass/alias_memory.cpp +++ b/csrc/device_lower/pass/alias_memory.cpp @@ -866,20 +866,47 @@ class AllocationInfoMap : private kir::IrVisitor { } void collectLivenessInfoOfExprMBarrier(Expr* expr) { - const auto expr_pos = scope_map_.getExprPos(expr); + int64_t expr_pos = scope_map_.getExprPos(expr); + auto mark_liveness = [&expr_pos, this](TensorView* tv, bool is_write) { + AllocationInfo* alloc_info = getAllocInfoFromTV(tv); + if (is_write) { + alloc_info->inner_live_interval->markWrite(expr_pos); + } else { + alloc_info->inner_live_interval->markRead(expr_pos); + } + ScopeInfo* outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info); + int64_t outer_pos = + outer_loop_info ? outer_loop_info->start_pos : expr_pos; + if (is_write) { + alloc_info->outer_live_interval->markWrite(outer_pos); + } else { + alloc_info->outer_live_interval->markRead(outer_pos); + } + }; + + // The liveness of the mbarrier and its token are mapped together. + // The token is the mbarrier state of the last phase. if (auto init = dynamic_cast(expr)) { - auto alloc_info = getAllocInfoFromTV(init->mbarrier()->as()); - alloc_info->inner_live_interval->markWrite(expr_pos); - auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info); - auto write_pos = outer_loop_info ? outer_loop_info->start_pos : expr_pos; - alloc_info->outer_live_interval->markWrite(write_pos); + mark_liveness(init->mbarrier()->as(), /*is_write=*/true); + + // Register start of lifetime for a mbarrier token returned by + // MBarrierArriveExpectTx and MBarrierArrive. + if (GpuLower::current()->ldstMBarrierTokenMap().count(expr) > 0) { + mark_liveness( + GpuLower::current()->ldstMBarrierTokenMap()[expr], + /*is_write=*/true); + } } else if (auto inval = dynamic_cast(expr)) { - auto alloc_info = getAllocInfoFromTV(inval->mbarrier()->as()); - alloc_info->inner_live_interval->markRead(expr_pos); - auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info); - auto write_pos = outer_loop_info ? outer_loop_info->start_pos : expr_pos; - alloc_info->outer_live_interval->markRead(write_pos); + mark_liveness(inval->mbarrier()->as(), /*is_write=*/false); + + // Register end of lifetime for a mbarrier token returned by + // returned by MBarrierArriveExpectTx and MBarrierArrive + if (GpuLower::current()->ldstMBarrierTokenMap().count(expr) > 0) { + mark_liveness( + GpuLower::current()->ldstMBarrierTokenMap()[expr], + /*is_write=*/false); + } } } diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index d78d909c8dd..f4be2dd804d 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -354,6 +354,8 @@ class AllocationInserter : public kir::ExprMutator { return; } + int64_t circular_buffer_depth = 1; + // Found where the allocation needs to be inserted for (const auto i : c10::irange(expr->outputs().size())) { @@ -425,6 +427,16 @@ class AllocationInserter : public kir::ExprMutator { auto alloc_expr = createAllocExpr(allocation, is_output); auto init_expr = createInitExpr(allocation, init); + // Check that all circular buffer depth match + if (out_tv->isCircularBuffered() && circular_buffer_depth == 1) { + circular_buffer_depth = out_tv->circularBufferDepth(); + } + NVF_ERROR( + circular_buffer_depth == 1 || + circular_buffer_depth == out_tv->circularBufferDepth(), + "Expected all output TensorViews for the same expression ", + "to have the same circular_buffer_depth"); + // Write information to GPULower writeInfoToGPULower(allocation, alloc_expr); @@ -458,33 +470,107 @@ class AllocationInserter : public kir::ExprMutator { // solution, we should remove this after we have a better way to handle // synchronizations for cp.async.bulk. if (ir_utils::isCpAsyncBulkLoad(expr)) { - // create and allocate a memory barrier - TensorView* mbarrier = TensorViewBuilder() - .shape(std::vector{}) - .dtype(DataType::UInt) - .contiguity(true) - .build(); - mbarrier->setMemoryType(MemoryType::Shared); - auto mbarrier_init = IrBuilder::create( - mbarrier, - simplifyExpr(SimplifyingIrBuilder::maybeCastExpr( - DataType::UInt32, - lower_utils::getNumThreadsInTensorView( - expr->output(0)->as())))); - auto sync_init = IrBuilder::create(); - auto mbarrier_inval = - IrBuilder::create(mbarrier); - auto sync_inval = IrBuilder::create(); - - kir::Allocate* mbarrier_alloc = - IrBuilder::create(mbarrier, MemoryType::Shared); - Scope* expr_scope = scope_.empty() ? nullptr : scope_.back(); - registerInsertBefore(expr, mbarrier_alloc, expr_scope); - registerInsertBefore(expr, mbarrier_init, expr_scope); - registerInsertBefore(expr, sync_init, expr_scope); - registerInsertAfter(expr, mbarrier_inval, expr_scope); - registerInsertAfter(expr, sync_inval, expr_scope); - GpuLower::current()->ldstMBarrierMap()[expr] = mbarrier; + if (circular_buffer_depth > 1) { + // Create and allocate a memory barrier. If this is a circular buffer, + // then allocate an array of mbarier objects. mbarrier::init and + // mbarrier::inval will be updated in circular buffering pass, but we + // add them here to handle shared memory correctly in alias memory pass. + TensorView* mbarrier = + TensorViewBuilder() + .shape(std::vector{circular_buffer_depth}) + .dtype(DataType::UInt) + .contiguity(true) + .build(); + mbarrier->setMemoryType(MemoryType::Shared); + + // The wait condition for mbarrier is a single thread and the expected + // number of transaction bytes + kir::MBarrierInit* mbarrier_init = IrBuilder::create( + mbarrier, expr->container()->oneVal(DataType::UInt32)); + + kir::Allocate* mbarrier_alloc = + IrBuilder::create(mbarrier, MemoryType::Shared); + + Scope* expr_scope = scope_.empty() ? nullptr : scope_.back(); + + kir::MBarrierInvalidate* mbarrier_inval = + IrBuilder::create(mbarrier); + + // For circular buffers we need to prepare a placeholder for the + // tokens created by 'MBarrierArriveExpectTx' IR node. The tokens are + // placed in shared memory and used by threads in a block. + TensorView* mbarrier_tokens = + TensorViewBuilder() + .shape(std::vector{circular_buffer_depth}) + .dtype(DataType::UInt) + .contiguity(true) + .build(); + mbarrier_tokens->setMemoryType(MemoryType::Shared); + + kir::Allocate* mbarrier_tokens_alloc = IrBuilder::create( + mbarrier_tokens, MemoryType::Shared); + + // Add tokens, mbarriers, init, and inval operations around tma + // expression like this: + // + // for (circular_buffer_loop) { + // __shared__ tokens[num_stages]; + // __shared__ mbarrier[num_stages]; + // init(mbarrier); + // cp.async.bulk(data, mbarrier); + // inval(mbarrier); + // } + + // NOTE: Block sync ir node is not added here. It will be added in the + // circular buffering pass + registerInsertBefore(expr, mbarrier_tokens_alloc, expr_scope); + registerInsertBefore(expr, mbarrier_alloc, expr_scope); + registerInsertBefore(expr, mbarrier_init, expr_scope); + registerInsertAfter(expr, mbarrier_inval, expr_scope); + + // Map LoadStoreOp expression to ir nodes created in this pass + GpuLower::current()->ldstMBarrierMap()[expr] = mbarrier; + GpuLower::current()->ldstMBarrierTokenMap()[expr] = mbarrier_tokens; + // Register tokens placeholder for MBarrierInit and MBarrierInvalidate, + // needed to manage life time of smem buffor in alias memory + GpuLower::current()->ldstMBarrierTokenMap()[mbarrier_init] = + mbarrier_tokens; + GpuLower::current()->ldstMBarrierTokenMap()[mbarrier_inval] = + mbarrier_tokens; + // Keep track of kir::Allocate for mBarrier and token objects, + // to simplify circular buffering pass logic + GpuLower::current()->mBarrierTokenSmemAllocSet().insert(mbarrier_alloc); + GpuLower::current()->mBarrierTokenSmemAllocSet().insert( + mbarrier_tokens_alloc); + } else { + // create and allocate a memory barrier + TensorView* mbarrier = TensorViewBuilder() + .shape(std::vector{}) + .dtype(DataType::UInt) + .contiguity(true) + .build(); + mbarrier->setMemoryType(MemoryType::Shared); + auto mbarrier_init = IrBuilder::create( + mbarrier, + simplifyExpr(SimplifyingIrBuilder::maybeCastExpr( + DataType::UInt32, + lower_utils::getNumThreadsInTensorView( + expr->output(0)->as())))); + auto sync_init = IrBuilder::create(); + auto mbarrier_inval = + IrBuilder::create(mbarrier); + auto sync_inval = IrBuilder::create(); + + kir::Allocate* mbarrier_alloc = + IrBuilder::create(mbarrier, MemoryType::Shared); + Scope* expr_scope = scope_.empty() ? nullptr : scope_.back(); + registerInsertBefore(expr, mbarrier_alloc, expr_scope); + registerInsertBefore(expr, mbarrier_init, expr_scope); + registerInsertBefore(expr, sync_init, expr_scope); + registerInsertAfter(expr, mbarrier_inval, expr_scope); + registerInsertAfter(expr, sync_inval, expr_scope); + GpuLower::current()->ldstMBarrierMap()[expr] = mbarrier; + } } }