Skip to content

Commit

Permalink
Merge pull request #526 from m-a-d-n-e-s-s/524-moldft-with-2-processe…
Browse files Browse the repository at this point in the history
…s-aborts

amends batching changes to WorldGopInterface::concat0
  • Loading branch information
fbischoff authored Feb 15, 2024
2 parents 31de6ea + b82563d commit 14435f3
Showing 1 changed file with 93 additions and 44 deletions.
137 changes: 93 additions & 44 deletions src/madness/world/worldgop.h
Original file line number Diff line number Diff line change
Expand Up @@ -906,45 +906,85 @@ namespace madness {
std::vector<T> concat0(const std::vector<T>& v, size_t bufsz=1024*1024) {
MADNESS_ASSERT(bufsz <= std::numeric_limits<int>::max());

SafeMPI::Request req0, req1;
ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(0, parent, child0, child1);
int child0_nbatch = 0, child1_nbatch = 0;

auto buf0 = std::unique_ptr<std::byte[]>(new std::byte[bufsz]);
auto buf1 = std::unique_ptr<std::byte[]>(new std::byte[bufsz]);

// transfer data in chunks at most this large
const int batch_size = static_cast<int>(std::min(static_cast<size_t>(max_reducebcast_msg_size()),bufsz));
std::deque<Tag> tags; // stores tags used to send each batch

auto batched_receives = [&,this](size_t buf_offset) {
MADNESS_ASSERT(batch_size <= bufsz);
Tag gsum_tag = world_.mpi.unique_tag();
tags.push_back(gsum_tag);

if (child0 != -1)
req0 = world_.mpi.Irecv(buf0.get() + buf_offset,
bufsz - batch_size, MPI_BYTE, child0,
gsum_tag);
if (child1 != -1)
req1 = world_.mpi.Irecv(buf1.get() + buf_offset,
bufsz - batch_size, MPI_BYTE, child1,
gsum_tag);

if (child0 != -1) {
World::await(req0);
}
if (child1 != -1) {
World::await(req1);
}
};

// receive data in batches
// precompute max # of tags any node ... will need, and allocate them on every node to avoid tag counter divergence
const int max_nbatch = bufsz / batch_size;
// one tag is reserved for sending the number of messages to expect and the size of the last message
const int max_ntags = max_nbatch + 1;
MADNESS_ASSERT(max_nbatch < world_.mpi.unique_tag_period());
std::vector<Tag> tags; // stores tags used to send each batch
tags.reserve(max_nbatch);
for(int t=0; t<max_ntags; ++t) tags.push_back(world_.mpi.unique_tag());

if (child0 != -1 || child1 != -1) {
// receive # of batches

auto receive_nbatch = [&,this]() {
if (child0 != -1) {
world_.mpi.Recv(&child0_nbatch, 1, MPI_INT, child0,
tags[0]);
}
if (child1 != -1) {
world_.mpi.Recv(&child1_nbatch, 1, MPI_INT, child1,
tags[0]);
}
};

receive_nbatch();

// receive data in batches

auto receive_batch = [&,this](int batch, size_t buf_offset) {
SafeMPI::Request req0, req1;
if (child0 != -1 && batch < child0_nbatch) {
int msg_size = batch_size;
// if last batch, receive # of bytes to expect
if (batch + 1 == child0_nbatch) {
auto req0 = world_.mpi.Irecv(
&msg_size, 1, MPI_INT, child0, tags[0]);
World::await(req0);
}

req0 = world_.mpi.Irecv(buf0.get() + buf_offset,
msg_size, MPI_BYTE, child0,
tags[batch + 1]);
}
if (child1 != -1 && batch < child1_nbatch) {
int msg_size = batch_size;
// if last batch, receive # of bytes to expect
if (batch + 1 == child1_nbatch) {
auto req1 = world_.mpi.Irecv(
&msg_size, 1, MPI_INT, child0, tags[0]);
World::await(req1);
}
req1 = world_.mpi.Irecv(buf1.get() + buf_offset,
bufsz - batch_size, MPI_BYTE, child1,
tags[batch + 1]);
}

if (child0 != -1 && batch < child0_nbatch) {
World::await(req0);
}
if (child1 != -1 && batch < child1_nbatch) {
World::await(req1);
}
};

size_t buf_offset = 0;
int batch = 0;
while (buf_offset < bufsz) {
batched_receives(buf_offset);
receive_batch(batch, buf_offset);
buf_offset += batch_size;
buf_offset = std::min(buf_offset, bufsz);
++batch;
}
}

Expand All @@ -968,27 +1008,36 @@ namespace madness {
ar & left;
const auto total_nbytes_to_send = ar.size();

auto batched_send = [&,this](size_t buf_offset) {
MADNESS_ASSERT(batch_size <= bufsz);
Tag gsum_tag;
if (tags.empty()) {
gsum_tag = world_.mpi.unique_tag();
} else {
gsum_tag = tags.front();
tags.pop_front();
}

const auto nbytes_to_send = static_cast<int>(std::min(static_cast<size_t>(batch_size), total_nbytes_to_send - buf_offset));
req0 = world_.mpi.Isend(buf0.get() + buf_offset, nbytes_to_send, MPI_BYTE, parent,
gsum_tag);
World::await(req0);
};

size_t buf_offset = 0;
int batch = 0;
while (buf_offset < bufsz) {
batched_send(buf_offset);

// send nbatches to expect
const int nbatch = (total_nbytes_to_send + batch_size - 1) / batch_size;
world_.mpi.Send(&nbatch, 1, MPI_INT, parent,
tags[0]);

// send data in batches
auto send_batch = [&,this](int batch, size_t buf_offset) {
const int nbytes_to_send = static_cast<int>(
std::min(static_cast<size_t>(batch_size),
total_nbytes_to_send - buf_offset));
// if last batch, send # of bytes to expect
if (batch + 1 == nbatch) {
auto req0 = world_.mpi.Isend(
&nbytes_to_send, 1, MPI_INT, parent, tags[0]);
World::await(req0);
}
auto req0 =
world_.mpi.Isend(buf0.get() + buf_offset, nbytes_to_send,
MPI_BYTE, parent, tags[batch + 1]);
World::await(req0);
};

send_batch(batch, buf_offset);
buf_offset += batch_size;
buf_offset = std::min(buf_offset, bufsz);
++batch;
}
}

Expand Down

0 comments on commit 14435f3

Please sign in to comment.