diff --git a/src/idtr.cpp b/src/idtr.cpp index 11869ea..f0da0d7 100644 --- a/src/idtr.cpp +++ b/src/idtr.cpp @@ -735,18 +735,28 @@ template class WaitPermute { SHARPY::rank_type cRank, SHARPY::rank_type nRanks, std::vector &&parts, std::vector &&axes, std::vector oGShape, ndarray &&input, - ndarray &&output, std::vector &&receiveBuffer, - std::vector &&receiveOffsets, + ndarray &&output, std::vector &&sendBuffer, + std::vector &&sendOffsets, std::vector &&sendSizes, + std::vector &&receiveBuffer, std::vector &&receiveOffsets, std::vector &&receiveSizes) : tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)), axes(std::move(axes)), oGShape(std::move(oGShape)), input(std::move(input)), output(std::move(output)), + sendBuffer(std::move(sendBuffer)), sendOffsets(std::move(sendOffsets)), + sendSizes(std::move(sendSizes)), receiveBuffer(std::move(receiveBuffer)), receiveOffsets(std::move(receiveOffsets)), receiveSizes(std::move(receiveSizes)) {} + // Only allow move + WaitPermute(const WaitPermute &) = delete; + WaitPermute &operator=(const WaitPermute &) = delete; + WaitPermute(WaitPermute &&) = default; + WaitPermute &operator=(WaitPermute &&) = default; + void operator()() { tc->wait(hdl); + std::vector> receiveRankBuffer(nRanks); for (size_t rank = 0; rank < nRanks; ++rank) { auto &rankBuffer = receiveRankBuffer[rank]; @@ -755,6 +765,7 @@ template class WaitPermute { receiveBuffer.begin() + receiveOffsets[rank] + receiveSizes[rank]); } + // FIXME: very low efficiency, need to improve std::vector receiveRankBufferCount(nRanks, 0); input.globalIndices([&](const id &inputIndex) { id outputIndex = inputIndex.permute(axes); @@ -777,6 +788,9 @@ template class WaitPermute { std::vector oGShape; ndarray input; ndarray output; + std::vector sendBuffer; + std::vector sendOffsets; + std::vector sendSizes; std::vector receiveBuffer; std::vector receiveOffsets; std::vector receiveSizes; @@ -870,6 +884,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype, for (auto i = 0ul; i < nRanks; ++i) { dspl[i] = 4 * i; } + tc->gather(parts.data(), counts.data(), dspl.data(), SHARPY::INT64, SHARPY::REPLICATED); @@ -919,10 +934,12 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype, sendOffsets.data(), sharpytype, receiveBuffer.data(), receiveSizes.data(), receiveOffsets.data()); - auto wait = WaitPermute(tc, hdl, cRank, nRanks, std::move(parts), - std::move(axes), std::move(oGShape), std::move(input), - std::move(output), std::move(receiveBuffer), - std::move(receiveOffsets), std::move(receiveSizes)); + auto wait = + WaitPermute(tc, hdl, cRank, nRanks, std::move(parts), std::move(axes), + std::move(oGShape), std::move(input), std::move(output), + std::move(sendBuffer), std::move(sendOffsets), + std::move(sendSizes), std::move(receiveBuffer), + std::move(receiveOffsets), std::move(receiveSizes)); assert(parts.empty() && axes.empty() && receiveBuffer.empty() && receiveOffsets.empty() && receiveSizes.empty());