From 1158543c9bd16190fe5820b41a7f4c7e891bc5d9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 5 Sep 2024 09:34:43 -0700 Subject: [PATCH] Use loop promotion and indexing traversal graph to find mismatched parallelization (#2875) See #2850 Stacked on #2901 The old code is still used by default. With `NVFUSER_ENABLE=id_model`, the new analysis is used. It's also used for tensors with non-conventional domains. This is required for #2851. It also enables previously disabled parallelization of the mismatching reshape test from #2684. I validated the change by comparing the results between the existing and new analyses with all the tests and benchmarks. The only mismatch was with the mismatching reshape test, for which the existing analysis declared a sync is required, whereas the new one correctly recognizes there's no cross-thread dependency. --- .../analysis/sync_information.cpp | 156 +++++++++++++++--- csrc/device_lower/lower2device.h | 4 + csrc/id_model/indexing.h | 7 +- tests/cpp/test_gpu3.cpp | 27 +++ tests/cpp/test_gpu_view.cpp | 4 +- 5 files changed, 169 insertions(+), 29 deletions(-) diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index bea1c7ac150..dc3c5e1c292 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -7,8 +7,11 @@ // clang-format on #include #include +#include +#include #include #include +#include #include @@ -492,7 +495,7 @@ SyncMap::SyncMap(Fusion* fusion) { producer_redundant_types & (~producer_redundant_use_types); for (const auto producer_i : c10::irange(producer->nDims())) { - auto producer_axis = producer->axis(producer_i); + auto producer_axis = producer->getLoopDomain().at(producer_i); auto producer_ptype = ca_map->getConcreteMappedID(producer_axis, IdMappingMode::LOOP) ->getParallelType(); @@ -516,7 +519,7 @@ SyncMap::SyncMap(Fusion* fusion) { std::vector consumer_parallel_ids( ParallelTypeBitmap::kNumParallelTypes, nullptr); for (const auto consumer_i : c10::irange(consumer->nDims())) { - auto consumer_axis = consumer->axis(consumer_i); + auto consumer_axis = consumer->getLoopDomain().at(consumer_i); auto consumer_ptype = ca_map->getConcreteMappedID(consumer_axis, IdMappingMode::LOOP) ->getParallelType(); @@ -541,6 +544,23 @@ SyncMap::SyncMap(Fusion* fusion) { ProducerConsumerIndexingInfoCache indexing_info(producer, consumer); + // P2C map is required when using the IdModel-based analysis + const std::unordered_map + p2c_map_no_forwarding = GpuLower::current()->hasIdModel() + ? BestEffortReplay( + consumer->getLoopDomain(), + producer->getLoopDomain(), + PairwiseLogicalDomainMap(producer, consumer) + .mapProducerToConsumer(), + /*replay_forward_id_map=*/{}, + /*target_forward_id_map=*/{}, + /*skip_replay_swizzle=*/false, + /*skip_target_swizzle=*/false, + /*skip_resize=*/false, + /*error_on_failure=*/false) + .getReplay() + : std::unordered_map{}; + // At this point each parallel type that's present in the consumer or // the producer will be present in their corresponding `_parallel_ids` // map going from parallel index type (only size 6 for grid/block dims) @@ -653,6 +673,7 @@ SyncMap::SyncMap(Fusion* fusion) { producer->getLogicalDomain(), {p_id}) .empty()) { raw_dims.set(producer_ptype); + continue; } } @@ -662,30 +683,115 @@ SyncMap::SyncMap(Fusion* fusion) { continue; } - // When the producer is parallelized, the producer and the - // consumer must use the same index with the same parallel - // type. Otherwise, a sync is required. This is not the case - // when this op is a parallel broadcast. - - if (producer_parallel_bcast) { - // As long as they are permissively mapped using the same - // parallel type, no communication is required - if (producer_ptype == consumer_ptype && - ca_map->areMapped(p_id, c_id, IdMappingMode::PERMISSIVE)) { - continue; + // Use the IdModel loop promotion when available. This is + // required for tensors with non-trivial loop domains + if (GpuLower::current()->hasIdModel()) { + if (producer_ptype == consumer_ptype) { + // Case 1: + // Producer loop ID: non-broadcast + // Consumer loop ID: non-broadcast + // -> No sync if they are exactly mapped. This case is covered by + // the promotion check. + // + // Case 2: + // Producer loop ID: broadcast (which may be produced by + // merging multiple broadcast domains) + // Consumer loop ID: non-broadcast + // -> They are not exactly mapped but sync is not necessary as + // discussed below. + // + // Case 3: + // Producer loop ID: non-broadcast + // Consumer loop ID: non-broadcast + // -> Sync required if they are not exactly mapped, even when they + // are mapped by the best effort replay. (See + // NVFuserTest.RAWSync for a concrete repro). + + // Case 1 + const auto& id_model = GpuLower::current()->idModel(); + auto producer_loop_id = + indexing_utils::getLoopPromotion(p_id, id_model); + auto consumer_loop_id = + indexing_utils::getLoopPromotion(c_id, id_model); + const auto& indexing_traveral_graph = + id_model.idGraph(TensorIndexer::traversalGraphType()); + if (indexing_traveral_graph.disjointValSets().strictAreMapped( + producer_loop_id, consumer_loop_id)) { + continue; + } + + // Case 2 + // If the producer ID is a broadcast, it does not + // require synchronization even when the producer and + // consumer domains are not promoted to the same + // group. For example, + // + // tv0: [i0] + // tv1: [b1] + // tv2 = tv1 + // tv3 = tv0 + tv2 + // + // tv2->axis(0)->parallelize(ParallelType::TIDx); + // tv3->axis(0)->parallelize(ParallelType::TIDx); + // + // Assume that there's no inlining. Since it isn't + // inlined, the loop domain of tv2 is not mapped with + // that of tv3, thus the avove condition won't + // hit. Still, since tv2 will be executed by all TIDx + // threads independently, there's no need of + // synchronization. + // + // Consider a similar case like below: + // + // tv0: [i0, i1] + // tv1: [i2, b3] + // tv2 = tv1 + // tv3 = tv0 + tv2 + // + // tv2->merge(0, 1); + // tv3->merge(0, 1); + // tv2->axis(0)->parallelize(ParallelType::TIDx); + // tv3->axis(0)->parallelize(ParallelType::TIDx); + // + // This case does require a synchronization since for + // tv2, TIDx will be used to parallelize the outer + // domain only, whereas for tv3 it is mapped to the + // merged domain of the outer and inner domains. In + // other words, if a broadcast becomes non-broadcast + // by getting merged with a non-broadcast domain, it + // requires a synchronization. + if (p_id->isBroadcast()) { + if (auto it = p2c_map_no_forwarding.find(p_id); + it != p2c_map_no_forwarding.end() && it->second == c_id) { + continue; + } + } + } + } else { + // When the producer is parallelized, the producer and the + // consumer must use the same index with the same parallel + // type. Otherwise, a sync is required. This is not the case + // when this op is a parallel broadcast. + if (producer_parallel_bcast) { + // As long as they are permissively mapped using the same + // parallel type, no communication is required + if (producer_ptype == consumer_ptype && + ca_map->areMapped(p_id, c_id, IdMappingMode::PERMISSIVE)) { + continue; + } + // Can this happen? + NVF_ERROR( + false, + "Unexpected case. Producer: ", + producer->toString(), + ", consumer: ", + consumer->toString()); + } + if (producer_ptype == consumer_ptype) { + if (useSameIndex(producer, p_id, consumer, c_id, indexing_info)) { + continue; + } } - // Can this happen? - NVF_ERROR( - false, - "Unexpected case. Producer: ", - producer->toString(), - ", consumer: ", - consumer->toString()); - } - - if (producer_ptype == consumer_ptype && - useSameIndex(producer, p_id, consumer, c_id, indexing_info)) { - continue; } raw_dims.set(producer_ptype); diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h index 3c2eb0d0ebb..42d94368440 100644 --- a/csrc/device_lower/lower2device.h +++ b/csrc/device_lower/lower2device.h @@ -112,6 +112,10 @@ class GpuLower : public NonCopyable { return std::const_pointer_cast(compute_at_map_); } + bool hasIdModel() const { + return id_model_.get() != nullptr; + } + IdModel& idModel() { NVF_ERROR(id_model_.get()); return *id_model_; diff --git a/csrc/id_model/indexing.h b/csrc/id_model/indexing.h index 473eefabeea..8aeb630e9d8 100644 --- a/csrc/id_model/indexing.h +++ b/csrc/id_model/indexing.h @@ -92,8 +92,13 @@ class TensorIndexer { // should not affect actual index exprs. // Returns non-const reference because indexing may create new domains and // need to update the graph. + + static IdMappingMode traversalGraphType() { + return IdMappingMode::ALMOSTEXACT; + } + ValGraph& traversalGraph() const { - return id_model_.idGraph(IdMappingMode::ALMOSTEXACT); + return id_model_.idGraph(traversalGraphType()); } // Traverse exprs and set allocation info for each tensor diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 521c1c6d4b7..29190520e44 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -8863,6 +8863,33 @@ TEST_F(NVFuserTest, BestEffortReplayWithMismatchedRootToLogical) { /*error_on_failure=*/false); } +TEST_F(NVFuserTest, RAWSync) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(1); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + + tv3->merge(0); + tv2->merge(0); + tv3->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::TIDx); + + // Since tv2 is not inlined and tv2 and tv3 are both parallelized, + // tv2 as a producer of tv3 requires a synchronization with tv2 + // placed on shared memory. Lowering the fusion should fail. + EXPECT_THAT( + [&]() { GpuLower(&fusion).run(); }, + testing::ThrowsMessage(testing::HasSubstr( + "Producer is required to be in Global or Shared Memory based on parallelization strategy. RAW flags: (threadIdx.x)"))); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser diff --git a/tests/cpp/test_gpu_view.cpp b/tests/cpp/test_gpu_view.cpp index c991e8607d6..47e10c92546 100644 --- a/tests/cpp/test_gpu_view.cpp +++ b/tests/cpp/test_gpu_view.cpp @@ -2701,11 +2701,9 @@ TEST_F(GpuViewTest, FusionMismatchingReshape) { // Parallelize all tensors as [BIDx, TIDx] schedule.merge(0); schedule.split(0, 128); -#if 0 - // TODO: sync analysis is not working yet + schedule.parallelize(0, ParallelType::BIDx); schedule.parallelize(1, ParallelType::TIDx); -#endif // Now, tv5 looks like: //