Harden assertBuffersHaveSameSize to check shapes. #3531

Merged · 1 commit · Dec 6, 2024
10 changes: 4 additions & 6 deletions csrc/host_ir/executor.cpp

@@ -378,15 +378,13 @@ void HostIrEvaluator::handle(Wait* wait) {
 
 namespace {
 
-void allConsumerValsOfHelper(
-    Val* val,
-    std::unordered_set<Val*>& visisted_vals) {
-  if (visisted_vals.find(val) != visisted_vals.end()) {
+void allConsumerValsOfHelper(Val* val, std::unordered_set<Val*>& visited_vals) {
+  if (visited_vals.find(val) != visited_vals.end()) {
     return;
   }
-  visisted_vals.insert(val);
+  visited_vals.insert(val);
   for (Val* consumer : ir_utils::consumerValsOf(val)) {
-    allConsumerValsOfHelper(consumer, visisted_vals);
+    allConsumerValsOfHelper(consumer, visited_vals);
   }
 }
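Besides reflowing the signature, this hunk fixes the "visisted_vals" typo. The helper does a depth-first walk collecting every transitive consumer of a Val. For context, a minimal sketch of a driving wrapper; the name allConsumerValsOf is assumed from the helper's name, as the actual entry point is outside this diff:

// Hypothetical wrapper, inferred from the helper's name. It seeds the
// visited set and returns every Val reachable through consumer edges.
std::unordered_set<Val*> allConsumerValsOf(Val* val) {
  std::unordered_set<Val*> visited_vals;
  allConsumerValsOfHelper(val, visited_vals);
  return visited_vals;
}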
9 changes: 6 additions & 3 deletions csrc/multidevice/communication.cpp

@@ -66,12 +66,15 @@ void assertBuffersHaveSameSize(
   if (bufs1.empty() && bufs2.empty()) {
     return;
   }
-  const auto numel = (bufs1.empty() ? bufs2 : bufs1).at(0).numel();
+  const auto shape = (bufs1.empty() ? bufs2 : bufs1).at(0).sizes();
   for (const auto& bufs : {bufs1, bufs2}) {
     for (const auto& buf : bufs) {
       NVF_ERROR(
-          buf.numel() == numel,
-          "all buffers must have the same number of elements");
+          buf.sizes() == shape,
+          "all buffers must have the same shape, but got: ",
+          buf.sizes(),
+          " vs ",
+          shape);
     }
   }
 }
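The hardening matters because an equal element count does not imply an equal shape. A minimal illustration of a case the old numel-based assertion would have accepted, using ATen directly (not part of the diff):

#include <ATen/ATen.h>

void shapeVsNumel() {
  at::Tensor a = at::zeros({2, 3});
  at::Tensor b = at::zeros({3, 2});
  // a.numel() == b.numel() == 6, so the old assertion passed, but
  // a.sizes() ([2, 3]) != b.sizes() ([3, 2]), so the hardened check now
  // reports the mismatch instead of silently accepting the buffers.
}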
30 changes: 24 additions & 6 deletions tests/cpp/test_multidevice_overlap.cpp

@@ -164,7 +164,7 @@ class OverlapTest : public MultiDeviceTest {
 
   void validate() {
     auto tc_expected = getExpectedResult();
-    auto tc_cpu = tc_.to(torch::kCPU);
+    auto tc_cpu = tc_.cpu();
    EXPECT_TRUE(tc_cpu.allclose(tc_expected, 1e-1, 1e-1))
        << "Unexpected results, obtained:" << tc_cpu
        << "\n expected: " << tc_expected;

@@ -837,18 +837,19 @@ TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningHostIrImplementation) {
         IrBuilder::create<hir::Stream>(stream_index));
 
     TensorView* tva_j = select(tva, 0, j);
+    TensorView* tva_j_unsqueezed = unsqueeze(tva_j, 0);
     TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
 
     // Setting the DeviceMesh of the communication's I/O is artificial but
     // required at this point
     DeviceMesh full_mesh(all_devices_);
     tva_allgathered_j->setDeviceMesh(full_mesh);
-    tva_j->setDeviceMesh(full_mesh);
+    tva_j_unsqueezed->setDeviceMesh(full_mesh);
 
     auto* communication = IrBuilder::create<Communication>(
         CommunicationType::Allgather,
         /*out=*/tva_allgathered_j,
-        /*in=*/tva_j,
+        /*in=*/tva_j_unsqueezed,
         /*team=*/all_devices_);
     auto* wait = IrBuilder::create<hir::Wait>(communication);
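This hunk is where the hardened assertBuffersHaveSameSize bites: the allgather input must now match the shape of each output slot, not just its element count. A sketch of the shapes involved, with assumed sizes (the actual extents in the test differ):

// Assuming 4 devices and per-device slices of 8 elements (hypothetical):
//   tva_j            : [8]     -- slice selected from the sharded input
//   tva_j_unsqueezed : [1, 8]  -- leading singleton aligns with the gather axis
//   tva_allgathered_j: [4, 8]  -- allgather concatenates along dim 0
// Each rank's input buffer and output slot both hold 8 elements, so the old
// numel-based check accepted plain tva_j; the sizes()-based check requires
// the explicit [1, 8] shape, hence the unsqueeze.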
@@ -864,6 +865,7 @@
     std::vector<Expr*> loop_body = {
         set_stream,
         tva_j->definition(),
+        tva_j_unsqueezed->definition(),
         tva_allgathered_j->definition(),
         communication,
         wait,
@@ -899,9 +901,25 @@
   for_loop_stream->body().push_back(sync_stream);
   hic->pushBackTopLevelExprs(for_loop_stream);
 
-  // The following line is artificial but necessary to make
-  // tva_j->isProducerOf(tvc_j) == true
-  hic->addOutput(tvc_j);
+  // The following line is artificial but necessary to make tva_j_unsqueezed
+  // a consumer of tva_j.
+  //
+  // HostIrEvaluator::handle(ForLoop*) relies on `Val::uses()` to find all
+  // **transitive** consumers of the loop index `j`. `tva_j_unsqueezed` is a
+  // bit special among all transitive consumers of `j`: it doesn't use `j`
+  // directly but uses `tva_j`, which is a TensorView. A TensorView's uses
+  // are built lazily by Fusion::resetTvUses. For efficiency,
+  // Fusion::resetTvUses only fixes TensorViews that can reach outputs.
+  // Therefore, we add tva_j_unsqueezed as an output. Other TensorViews don't
+  // need this treatment because they are direct users of `j`, a scalar whose
+  // uses are built eagerly upon registration.
+  //
+  // We could have added `tvc_j` instead as an output, which transitively
+  // consumes `tva_j_unsqueezed`. However, `tvc_j` has two definitions, a
+  // Select and a MatmulOp, and StmtSort::getExprs only traverses via the
+  // first registered definition (i.e. the Select). This sounds like a bug --
+  // I wonder how nvFuser resets the TensorView uses of a kir::Kernel, which
+  // is also non-SSA.
+  hic->addOutput(tva_j_unsqueezed);
 
   hir::HostIrEvaluator hie(std::move(hic), communicator_);

Review thread on "StmtSort::getExprs only traverses via the first registered definition":

Collaborator: StmtSort and other stuff in iter_visitor.h assume the SSA property of Fusion.

Collaborator Author: Does that mean TVs in a kir::Kernel (also non-SSA) get a wrong Val::uses() that we should avoid relying on?

Collaborator: It may actually be used, but non-SSA definitions in the Kernel IR are pretty limited so far, so we may not encounter any problems. In general, though, it isn't a well-ironed-out scenario.

Collaborator Author: Got it. I'll have to revisit how we evaluate for-loops. One potential approach is to only invalidate loop-index-dependent scalars and let TensorView ops in the loop body run unconditionally.
Review thread on hic->addOutput(tva_j_unsqueezed):

Collaborator Author: The introduction of tva_j_unsqueezed triggered a weird problem that @samnordmann is probably aware of. I added more explanation and wonder what @naoyam thinks about this.

Collaborator: We could make a TensorView live even if no output depends on it for HostIR. Not sure if that would solve the issue, though, as I'm still not entirely clear on what the issue is.

Collaborator: I am aware that we artificially need to add the matmul's output as a fusion output to fix the data dependency; that is why tvc_j was added in the first place. However, I was not aware of the other bug you're mentioning -- that we only traverse the first registered producing Expr. Would the program break if you only kept hic->addOutput(tvc_j);?

Collaborator Author: Yes, as I commented at https://github.com/NVIDIA/Fuser/pull/3531/files#diff-30df6421558f87ef0024b01f11752c35d3d68b80a9e6e0ec0fd49de535acb91aR917

Collaborator: OK, but I'm not sure I fully understand why it breaks. Even if the visitor only traverses the first definition, i.e. the SelectOp, tvc_j should still be invalidated because the SelectOp consumes the index j.

Collaborator Author: Yes, tvc_j will be invalidated, but tva_j_unsqueezed won't be. As a result, it will always hold the value from the first iteration.

