diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp
index 66167992fc5..cea33579fae 100644
--- a/csrc/device_lower/analysis/sync_information.cpp
+++ b/csrc/device_lower/analysis/sync_information.cpp
@@ -7,6 +7,7 @@
 // clang-format on
 #include
 #include
+#include <id_model/indexing.h>
 #include
 #include
 #include
@@ -543,19 +544,22 @@ SyncMap::SyncMap(Fusion* fusion) {

         ProducerConsumerIndexingInfoCache indexing_info(producer, consumer);

-        const auto p2c_map_no_forwarding =
-            BestEffortReplay(
-                consumer->getLoopDomain(),
-                producer->getLoopDomain(),
-                PairwiseLogicalDomainMap(producer, consumer)
-                    .mapProducerToConsumer(),
-                /*replay_forward_id_map=*/{},
-                /*target_forward_id_map=*/{},
-                /*skip_replay_swizzle=*/false,
-                /*skip_target_swizzle=*/false,
-                /*skip_resize=*/false,
-                /*error_on_failure=*/false)
-                .getReplay();
+        // P2C map is required when using the IdModel-based analysis
+        const std::unordered_map<IterDomain*, IterDomain*>
+            p2c_map_no_forwarding = GpuLower::current()->hasIdModel()
+            ? BestEffortReplay(
+                  consumer->getLoopDomain(),
+                  producer->getLoopDomain(),
+                  PairwiseLogicalDomainMap(producer, consumer)
+                      .mapProducerToConsumer(),
+                  /*replay_forward_id_map=*/{},
+                  /*target_forward_id_map=*/{},
+                  /*skip_replay_swizzle=*/false,
+                  /*skip_target_swizzle=*/false,
+                  /*skip_resize=*/false,
+                  /*error_on_failure=*/false)
+                  .getReplay()
+            : std::unordered_map<IterDomain*, IterDomain*>{};

         // At this point each parallel type that's present in the consumer or
         // the producer will be present in their corresponding `_parallel_ids`
@@ -679,13 +683,8 @@ SyncMap::SyncMap(Fusion* fusion) {
             continue;
           }

-          // Check if the IdModel-based analysis matches with the
-          // existing analysis
-          bool requires_sync = true;
-
-          IterDomain* p_loop = nullptr;
-          IterDomain* c_loop = nullptr;
-
+          // Use the IdModel loop promotion when available. This is
+          // required for tensors with non-trivial loop domains
           if (GpuLower::current()->hasIdModel()) {
             if (producer_ptype == consumer_ptype) {
               // If p_id and c_id are mapped in BestEffortReplay with
@@ -693,7 +692,7 @@ SyncMap::SyncMap(Fusion* fusion) {
               // synchronization.
               if (auto it = p2c_map_no_forwarding.find(p_id);
                   it != p2c_map_no_forwarding.end() && it->second == c_id) {
-                requires_sync = false;
+                continue;
               } else {
                 // Even if not mapped in BestEffortReplay, inlining
                 // may effectively make the producer and consumer
@@ -704,56 +703,40 @@ SyncMap::SyncMap(Fusion* fusion) {
                 const auto& id_model = GpuLower::current()->idModel();
                 auto producer_loop_id =
                     indexing_utils::getLoopPromotion(p_id, id_model);
-                p_loop = producer_loop_id;
                 auto consumer_loop_id =
                     indexing_utils::getLoopPromotion(c_id, id_model);
-                c_loop = consumer_loop_id;
                 const auto& indexing_traveral_graph =
-                    id_model.idGraph(IdMappingMode::ALMOSTEXACT);
+                    id_model.idGraph(TensorIndexer::traversalGraphType());
                 if (indexing_traveral_graph.disjointValSets().strictAreMapped(
                         producer_loop_id, consumer_loop_id)) {
-                  requires_sync = false;
+                  continue;
                 }
               }
             }
-          }
-
-          // When the producer is parallelized, the producer and the
-          // consumer must use the same index with the same parallel
-          // type. Otherwise, a sync is required. This is not the case
-          // when this op is a parallel broadcast.
-          if (producer_parallel_bcast) {
-            // As long as they are permissively mapped using the same
-            // parallel type, no communication is required
-            if (producer_ptype == consumer_ptype &&
-                ca_map->areMapped(p_id, c_id, IdMappingMode::PERMISSIVE)) {
+          } else {
+            // When the producer is parallelized, the producer and the
+            // consumer must use the same index with the same parallel
+            // type. Otherwise, a sync is required. This is not the case
+            // when this op is a parallel broadcast.
+            if (producer_parallel_bcast) {
+              // As long as they are permissively mapped using the same
+              // parallel type, no communication is required
+              if (producer_ptype == consumer_ptype &&
+                  ca_map->areMapped(p_id, c_id, IdMappingMode::PERMISSIVE)) {
+                continue;
+              }
+              // Can this happen?
               NVF_ERROR(
-                  !requires_sync,
-                  expr->toString(),
-                  p_id->toString(),
-                  " (",
-                  p_loop->toString(),
-                  ")",
-                  ", ",
-                  c_id->toString(),
-                  " (",
-                  c_loop->toString(),
-                  ")");
-              continue;
+                  false,
+                  "Unexpected case. Producer: ",
+                  producer->toString(),
+                  ", consumer: ",
+                  consumer->toString());
             }
-            // Can this happen?
-            NVF_ERROR(
-                false,
-                "Unexpected case. Producer: ",
-                producer->toString(),
-                ", consumer: ",
-                consumer->toString());
-          }
-
-          if (producer_ptype == consumer_ptype) {
-            if (useSameIndex(producer, p_id, consumer, c_id, indexing_info)) {
-              NVF_ERROR(!requires_sync);
-              continue;
+            if (producer_ptype == consumer_ptype) {
+              if (useSameIndex(producer, p_id, consumer, c_id, indexing_info)) {
+                continue;
+              }
             }
           }
diff --git a/csrc/id_model/indexing.h b/csrc/id_model/indexing.h
index 473eefabeea..8aeb630e9d8 100644
--- a/csrc/id_model/indexing.h
+++ b/csrc/id_model/indexing.h
@@ -92,8 +92,13 @@ class TensorIndexer {
   // should not affect actual index exprs.
   // Returns non-const reference because indexing may create new domains and
   // need to update the graph.
+
+  static IdMappingMode traversalGraphType() {
+    return IdMappingMode::ALMOSTEXACT;
+  }
+
   ValGraph& traversalGraph() const {
-    return id_model_.idGraph(IdMappingMode::ALMOSTEXACT);
+    return id_model_.idGraph(traversalGraphType());
   }

   // Traverse exprs and set allocation info for each tensor
diff --git a/tests/cpp/test_gpu_view.cpp b/tests/cpp/test_gpu_view.cpp
index c991e8607d6..47e10c92546 100644
--- a/tests/cpp/test_gpu_view.cpp
+++ b/tests/cpp/test_gpu_view.cpp
@@ -2701,11 +2701,9 @@ TEST_F(GpuViewTest, FusionMismatchingReshape) {
   // Parallelize all tensors as [BIDx, TIDx]
   schedule.merge(0);
   schedule.split(0, 128);
-#if 0
-  // TODO: sync analysis is not working yet
+
   schedule.parallelize(0, ParallelType::BIDx);
   schedule.parallelize(1, ParallelType::TIDx);
-#endif

   // Now, tv5 looks like:
   //
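For reviewers who want the control-flow change of the first file in isolation: below is a minimal standalone sketch, not NVFuser code, of the restructured decision flow. `Id`, `has_id_model`, `buildP2C`, `loopPromotionsMapped`, and `permissivelyMapped` are hypothetical stand-ins for `IterDomain*`, `GpuLower::current()->hasIdModel()`, `BestEffortReplay(...).getReplay()`, the `strictAreMapped` check on the loop promotions, and `ca_map->areMapped(..., PERMISSIVE)` respectively; parallel-type bookkeeping (`producer_ptype`, `producer_parallel_bcast`) is omitted. The P2C map is built only when the IdModel analysis is enabled, and each satisfied check exits the loop iteration with `continue` instead of toggling a `requires_sync` flag.

```cpp
#include <unordered_map>
#include <utility>
#include <vector>

struct Id {};  // stand-in for IterDomain

bool has_id_model = true;  // stand-in for GpuLower::current()->hasIdModel()

// Stand-in for BestEffortReplay(...).getReplay(); maps producer to consumer.
std::unordered_map<Id*, Id*> buildP2C() {
  return {};
}

// Stand-in for strictAreMapped() on the loop promotions of p and c.
bool loopPromotionsMapped(Id* /*p*/, Id* /*c*/) {
  return false;
}

// Stand-in for ca_map->areMapped(p, c, IdMappingMode::PERMISSIVE).
bool permissivelyMapped(Id* /*p*/, Id* /*c*/) {
  return false;
}

// Returns true if any producer/consumer ID pair still requires a sync.
bool anyPairNeedsSync(const std::vector<std::pair<Id*, Id*>>& pairs) {
  // The P2C map is needed only on the IdModel path, so build it conditionally,
  // mirroring the new ternary initialization of p2c_map_no_forwarding.
  const std::unordered_map<Id*, Id*> p2c_no_forwarding =
      has_id_model ? buildP2C() : std::unordered_map<Id*, Id*>{};
  for (const auto& [p_id, c_id] : pairs) {
    if (has_id_model) {
      // Mapped without forwarding: producer and consumer share an index.
      if (auto it = p2c_no_forwarding.find(p_id);
          it != p2c_no_forwarding.end() && it->second == c_id) {
        continue;
      }
      // Inlining may still promote both IDs to the same loop.
      if (loopPromotionsMapped(p_id, c_id)) {
        continue;
      }
    } else {
      // Legacy path: a permissive compute-at mapping suffices.
      if (permissivelyMapped(p_id, c_id)) {
        continue;
      }
    }
    return true;  // this pair requires a synchronization
  }
  return false;
}

int main() {
  Id p, c;
  return anyPairNeedsSync({{&p, &c}}) ? 1 : 0;
}
```

The design point the sketch highlights: the IdModel-based path and the legacy compute-at-map path become mutually exclusive branches rather than a cross-check of one against the other, which is what lets the patch delete the `requires_sync` flag and the `p_loop`/`c_loop` debug state.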