diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index 38906669367..1f8dc97bbfa 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -378,15 +378,13 @@ void HostIrEvaluator::handle(Wait* wait) {
 
 namespace {
 
-void allConsumerValsOfHelper(
-    Val* val,
-    std::unordered_set<Val*>& visisted_vals) {
-  if (visisted_vals.find(val) != visisted_vals.end()) {
+void allConsumerValsOfHelper(Val* val, std::unordered_set<Val*>& visited_vals) {
+  if (visited_vals.find(val) != visited_vals.end()) {
     return;
   }
-  visisted_vals.insert(val);
+  visited_vals.insert(val);
   for (Val* consumer : ir_utils::consumerValsOf(val)) {
-    allConsumerValsOfHelper(consumer, visisted_vals);
+    allConsumerValsOfHelper(consumer, visited_vals);
   }
 }
diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index 29ef6995969..522b755b96d 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -66,12 +66,15 @@ void assertBuffersHaveSameSize(
   if (bufs1.empty() && bufs2.empty()) {
     return;
   }
-  const auto numel = (bufs1.empty() ? bufs2 : bufs1).at(0).numel();
+  const auto shape = (bufs1.empty() ? bufs2 : bufs1).at(0).sizes();
   for (const auto& bufs : {bufs1, bufs2}) {
     for (const auto& buf : bufs) {
       NVF_ERROR(
-          buf.numel() == numel,
-          "all buffers must have the same number of elements");
+          buf.sizes() == shape,
+          "all buffers must have the same shape, but got: ",
+          buf.sizes(),
+          " vs ",
+          shape);
     }
   }
 }
diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp
index b9ca9ee5a91..e371a53dd4d 100644
--- a/tests/cpp/test_multidevice_overlap.cpp
+++ b/tests/cpp/test_multidevice_overlap.cpp
@@ -164,7 +164,7 @@ class OverlapTest : public MultiDeviceTest {
 
   void validate() {
     auto tc_expected = getExpectedResult();
-    auto tc_cpu = tc_.to(torch::kCPU);
+    auto tc_cpu = tc_.cpu();
     EXPECT_TRUE(tc_cpu.allclose(tc_expected, 1e-1, 1e-1))
         << "Unexpected results, obtained:" << tc_cpu
         << "\n expected: " << tc_expected;
@@ -837,18 +837,19 @@ TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningHostIrImplementation) {
         IrBuilder::create<hir::Stream>(stream_index));
 
     TensorView* tva_j = select(tva, 0, j);
+    TensorView* tva_j_unsqueezed = unsqueeze(tva_j, 0);
     TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
 
     // Setting the DeviceMesh of the communication's I/O is artificial but
     // required at this point
     DeviceMesh full_mesh(all_devices_);
     tva_allgathered_j->setDeviceMesh(full_mesh);
-    tva_j->setDeviceMesh(full_mesh);
+    tva_j_unsqueezed->setDeviceMesh(full_mesh);
 
     auto* communication = IrBuilder::create<Communication>(
         CommunicationType::Allgather,
         /*out=*/tva_allgathered_j,
-        /*in=*/tva_j,
+        /*in=*/tva_j_unsqueezed,
         /*team=*/all_devices_);
     auto* wait = IrBuilder::create<hir::Wait>(communication);
 
@@ -864,6 +865,7 @@ TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningHostIrImplementation) {
     std::vector<Expr*> loop_body = {
         set_stream,
         tva_j->definition(),
+        tva_j_unsqueezed->definition(),
        tva_allgathered_j->definition(),
         communication,
         wait,
@@ -899,9 +901,25 @@ TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningHostIrImplementation) {
 
   for_loop_stream->body().push_back(sync_stream);
   hic->pushBackTopLevelExprs(for_loop_stream);
 
-  // The following line is artificial but necessary to make
-  // tva_j->isProducerOf(tvc_j) == true
-  hic->addOutput(tvc_j);
+  // The following line is artificial but necessary to make tva_j_unsqueezed
+  // a consumer of tva_j.
+  //
+  // HostIrEvaluator::handle(ForLoop*) relies on `Val::uses()` to find all
+  // **transitive** consumers of the loop index `j`. `tva_j_unsqueezed` is a
+  // bit special among those consumers: it doesn't use `j` directly but uses
+  // `tva_j`, which is a TensorView. A TensorView's uses are built lazily by
+  // Fusion::resetTvUses, which for efficiency only fixes TensorViews that can
+  // reach outputs. Therefore, we add tva_j_unsqueezed as an output. Other
+  // TensorViews don't need this treatment because they are direct users of
+  // `j`, a scalar whose uses are built eagerly upon registration.
+  //
+  // We could have added `tvc_j` instead as an output, which transitively
+  // consumes `tva_j_unsqueezed`. However, `tvc_j` has two definitions, a
+  // Select and a MatmulOp, and StmtSort::getExprs only traverses via the
+  // first registered definition (i.e. the Select). This sounds like a bug --
+  // I wonder how nvFuser resets the TensorView uses of a kir::Kernel, which
+  // is also non-SSA.
+  hic->addOutput(tva_j_unsqueezed);
 
   hir::HostIrEvaluator hie(std::move(hic), communicator_);
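Note on the communication.cpp hunk: the old assert compared only element counts, so two buffers with equal numel but different shapes (e.g. [2, 3] vs [3, 2]) slipped through. A minimal standalone ATen sketch, not part of this patch, showing what the stricter sizes()-based check catches:

    // Sketch: equal numel, different shapes.
    // The old numel()-based assert would accept this pair; the new
    // sizes()-based one rejects it.
    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      at::Tensor a = at::zeros({2, 3});
      at::Tensor b = at::zeros({3, 2});
      std::cout << (a.numel() == b.numel()) << "\n";  // 1: numel check passes
      std::cout << (a.sizes() == b.sizes()) << "\n";  // 0: shape check fails
      return 0;
    }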
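Note on the final test hunk: below is a minimal sketch of the `Val::uses()`-based traversal the comment describes; it is an illustration under stated assumptions, not nvFuser's actual `HostIrEvaluator::handle(ForLoop*)` code, and `collectTransitiveConsumers` is a hypothetical name. Starting from `j`, the walk only reaches `tva_j_unsqueezed` if `tva_j->uses()` has been populated by Fusion::resetTvUses, which is exactly why the test registers `tva_j_unsqueezed` as an output:

    // Sketch: collect all transitive consumers of `val` by walking Val::uses().
    // Assumes nvFuser's Val/Expr IR nodes (csrc/ir/all_nodes.h).
    void collectTransitiveConsumers(Val* val, std::unordered_set<Val*>& visited) {
      if (!visited.insert(val).second) {
        return;  // already visited
      }
      // uses() is empty for a TensorView whose uses were never (re)built,
      // so such a consumer would be silently skipped here.
      for (Expr* use : val->uses()) {
        for (Val* out : use->outputs()) {
          collectTransitiveConsumers(out, visited);
        }
      }
    }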