Skip to content

Commit

Permalink
Merge pull request #553 from EricaCMitchell/pr_issue_551
Browse files Browse the repository at this point in the history
Fix to concat0
  • Loading branch information
evaleev authored Oct 15, 2024
2 parents 93a9a5c + e740317 commit f76b84a
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions src/madness/world/worldgop.h
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,6 @@ namespace madness {
};
using sptr_t = std::unique_ptr<T[], free_dtor>;

sptr_t buf0;
auto aligned_buf_alloc = [&]() -> T* {
// posix_memalign requires the alignment to be a multiple of sizeof(void*), so ensure that
const std::size_t alignment =
Expand All @@ -813,9 +812,11 @@ namespace madness {
}
return static_cast<T *>(ptr);
#else
return static_cast<T*>(std::aligned_alloc(alignment, buf_size));
return static_cast<T *>(std::aligned_alloc(alignment, buf_size));
#endif
};

sptr_t buf0;
if (child0 != -1)
buf0 = sptr_t(aligned_buf_alloc(),
free_dtor{});
Expand Down Expand Up @@ -951,6 +952,8 @@ namespace madness {
template <typename T>
std::vector<T> concat0(const std::vector<T>& v, size_t bufsz=1024*1024) {
MADNESS_ASSERT(bufsz <= std::numeric_limits<int>::max());
// std::aligned_alloc requires the size to be an integral multiple of the alignment, so round bufsz up to a multiple of sizeof(void*)
bufsz = ((bufsz + sizeof(void*) - 1) / sizeof(void*)) * sizeof(void*);

ProcessID parent, child0, child1;
world_.mpi.binary_tree_info(0, parent, child0, child1);
Expand All @@ -964,20 +967,16 @@ namespace madness {
};
using sptr_t = std::unique_ptr<std::byte[], free_dtor>;

sptr_t buf0;
if (child0 != -1)
buf0 = sptr_t(static_cast<std::byte *>(std::aligned_alloc(
std::alignment_of_v<T>, bufsz)),
free_dtor{});
sptr_t buf1;
if (child1 != -1)
buf1 = sptr_t(static_cast<std::byte *>(std::aligned_alloc(
std::alignment_of_v<T>, bufsz)),
free_dtor{});
auto buf0 = sptr_t(static_cast<std::byte *>(
std::aligned_alloc(sizeof(void *), bufsz)),
free_dtor{});
auto buf1 = sptr_t(static_cast<std::byte *>(
std::aligned_alloc(sizeof(void *), bufsz)),
free_dtor{});

// transfer data in chunks at most this large
const int batch_size = static_cast<int>(
std::min(static_cast<size_t>(max_reducebcast_msg_size()),bufsz));
std::min(static_cast<size_t>(max_reducebcast_msg_size()), bufsz));

// precompute max # of tags any node ... will need, and allocate them on every node to avoid tag counter divergence
const int max_nbatch = bufsz / batch_size;
Expand Down

0 comments on commit f76b84a

Please sign in to comment.