Skip to content

Commit

Permalink
cleanup and enable the previously-failing reshape parallelization
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Sep 4, 2024
1 parent 1371a85 commit 72ae077
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 65 deletions.
105 changes: 44 additions & 61 deletions csrc/device_lower/analysis/sync_information.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// clang-format on
#include <device_lower/analysis/index_compute.h>
#include <device_lower/lower2device.h>
#include <id_model/indexing.h>
#include <id_model/indexing_utils.h>
#include <instrumentation.h>
#include <ir/utils.h>
Expand Down Expand Up @@ -543,19 +544,22 @@ SyncMap::SyncMap(Fusion* fusion) {

ProducerConsumerIndexingInfoCache indexing_info(producer, consumer);

const auto p2c_map_no_forwarding =
BestEffortReplay(
consumer->getLoopDomain(),
producer->getLoopDomain(),
PairwiseLogicalDomainMap(producer, consumer)
.mapProducerToConsumer(),
/*replay_forward_id_map=*/{},
/*target_forward_id_map=*/{},
/*skip_replay_swizzle=*/false,
/*skip_target_swizzle=*/false,
/*skip_resize=*/false,
/*error_on_failure=*/false)
.getReplay();
// P2C map is required when using the IdModel-based analysis
const std::unordered_map<IterDomain*, IterDomain*>
p2c_map_no_forwarding = GpuLower::current()->hasIdModel()
? BestEffortReplay(
consumer->getLoopDomain(),
producer->getLoopDomain(),
PairwiseLogicalDomainMap(producer, consumer)
.mapProducerToConsumer(),
/*replay_forward_id_map=*/{},
/*target_forward_id_map=*/{},
/*skip_replay_swizzle=*/false,
/*skip_target_swizzle=*/false,
/*skip_resize=*/false,
/*error_on_failure=*/false)
.getReplay()
: std::unordered_map<IterDomain*, IterDomain*>{};

// At this point each parallel type that's present in the consumer or
// the producer will be present in their corresponding `_parallel_ids`
Expand Down Expand Up @@ -679,21 +683,16 @@ SyncMap::SyncMap(Fusion* fusion) {
continue;
}

// Check if the IdModel-based analysis matches with the
// existing analysis
bool requires_sync = true;

IterDomain* p_loop = nullptr;
IterDomain* c_loop = nullptr;

// Use the IdModel loop promotion when available. This is
// required for tensors with non-trivial loop domains
if (GpuLower::current()->hasIdModel()) {
if (producer_ptype == consumer_ptype) {
// If p_id and c_id are mapped in BestEffortReplay with
// no broadcast forwarding, they should not require any
// synchronization.
if (auto it = p2c_map_no_forwarding.find(p_id);
it != p2c_map_no_forwarding.end() && it->second == c_id) {
requires_sync = false;
continue;
} else {
// Even if not mapped in BestEffortReplay, inlining
// may effectively make the producer and consumer
Expand All @@ -704,56 +703,40 @@ SyncMap::SyncMap(Fusion* fusion) {
const auto& id_model = GpuLower::current()->idModel();
auto producer_loop_id =
indexing_utils::getLoopPromotion(p_id, id_model);
p_loop = producer_loop_id;
auto consumer_loop_id =
indexing_utils::getLoopPromotion(c_id, id_model);
c_loop = consumer_loop_id;
const auto& indexing_traveral_graph =
id_model.idGraph(IdMappingMode::ALMOSTEXACT);
id_model.idGraph(TensorIndexer::traversalGraphType());
if (indexing_traveral_graph.disjointValSets().strictAreMapped(
producer_loop_id, consumer_loop_id)) {
requires_sync = false;
continue;
}
}
}
}

// When the producer is parallelized, the producer and the
// consumer must use the same index with the same parallel
// type. Otherwise, a sync is required. This is not the case
// when this op is a parallel broadcast.
if (producer_parallel_bcast) {
// As long as they are permissively mapped using the same
// parallel type, no communication is required
if (producer_ptype == consumer_ptype &&
ca_map->areMapped(p_id, c_id, IdMappingMode::PERMISSIVE)) {
} else {
// When the producer is parallelized, the producer and the
// consumer must use the same index with the same parallel
// type. Otherwise, a sync is required. This is not the case
// when this op is a parallel broadcast.
if (producer_parallel_bcast) {
// As long as they are permissively mapped using the same
// parallel type, no communication is required
if (producer_ptype == consumer_ptype &&
ca_map->areMapped(p_id, c_id, IdMappingMode::PERMISSIVE)) {
continue;
}
// Can this happen?
NVF_ERROR(
!requires_sync,
expr->toString(),
p_id->toString(),
" (",
p_loop->toString(),
")",
", ",
c_id->toString(),
" (",
c_loop->toString(),
")");
continue;
false,
"Unexpected case. Producer: ",
producer->toString(),
", consumer: ",
consumer->toString());
}
// Can this happen?
NVF_ERROR(
false,
"Unexpected case. Producer: ",
producer->toString(),
", consumer: ",
consumer->toString());
}

if (producer_ptype == consumer_ptype) {
if (useSameIndex(producer, p_id, consumer, c_id, indexing_info)) {
NVF_ERROR(!requires_sync);
continue;
if (producer_ptype == consumer_ptype) {
if (useSameIndex(producer, p_id, consumer, c_id, indexing_info)) {
continue;
}
}
}

Expand Down
7 changes: 6 additions & 1 deletion csrc/id_model/indexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,13 @@ class TensorIndexer {
// should not affect actual index exprs.
// Returns a non-const reference because indexing may create new domains and
// needs to update the graph.

// Mapping mode of the ID graph that indexing traverses. Declared static so
// the graph type can be named without a TensorIndexer instance; always
// returns IdMappingMode::ALMOSTEXACT.
static IdMappingMode traversalGraphType() {
return IdMappingMode::ALMOSTEXACT;
}

// Returns the graph of id_model_ used for index traversal. The mapping mode
// is obtained from traversalGraphType() so that it is spelled out in exactly
// one place.
ValGraph& traversalGraph() const {
  return id_model_.idGraph(traversalGraphType());
}

// Traverse exprs and set allocation info for each tensor
Expand Down
4 changes: 1 addition & 3 deletions tests/cpp/test_gpu_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2701,11 +2701,9 @@ TEST_F(GpuViewTest, FusionMismatchingReshape) {
// Parallelize all tensors as [BIDx, TIDx]
schedule.merge(0);
schedule.split(0, 128);
#if 0
// TODO: sync analysis is not working yet

schedule.parallelize(0, ParallelType::BIDx);
schedule.parallelize(1, ParallelType::TIDx);
#endif

// Now, tv5 looks like:
//
Expand Down

0 comments on commit 72ae077

Please sign in to comment.